// Package admin provides HTTP client functionality for communicating
// with the Management API from the Admin UI server.
package admin
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"time"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
)
// context keys used to forward browser metadata from Admin UI → Management API
type forwardedCtxKey string
const (
ctxKeyForwardedUA forwardedCtxKey = "forwarded_user_agent"
ctxKeyForwardedReferer forwardedCtxKey = "forwarded_referer"
ctxKeyForwardedIP forwardedCtxKey = "forwarded_ip"
)
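// Handlers seed these keys from the incoming browser request before calling
// the client, along the lines of (clientIP here is a stand-in):
//
//	ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
//	ctx = context.WithValue(ctx, ctxKeyForwardedIP, clientIP)
//
// newRequest then maps the values onto X-Forwarded-* headers.
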
// APIClient handles communication with the Management API
type APIClient struct {
baseURL string
token string
httpClient *http.Client
}
// NewAPIClient creates a new Management API client
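// A minimal usage sketch (the base URL and token source are placeholders):
//
//	client := NewAPIClient("http://localhost:8080", os.Getenv("MANAGEMENT_TOKEN"))
//	projects, _, err := client.GetProjects(context.Background(), 1, 20)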
func NewAPIClient(baseURL, token string) *APIClient {
return &APIClient{
baseURL: baseURL,
token: token,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// ObfuscateAPIKey obfuscates an API key for display purposes by delegating
// to the shared obfuscation helper.
func ObfuscateAPIKey(apiKey string) string { return obfuscate.ObfuscateTokenGeneric(apiKey) }
// ObfuscateToken obfuscates a token for display purposes.
// It delegates to the centralized helper to avoid referencing the deprecated
// wrapper (and triggering its lint warning).
func ObfuscateToken(token string) string { return obfuscate.ObfuscateTokenGeneric(token) }
// Project represents a project from the Management API
type Project struct {
ID string `json:"id"`
Name string `json:"name"`
OpenAIAPIKey string `json:"openai_api_key"`
IsActive bool `json:"is_active"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// Token represents a token from the Management API (sanitized)
type Token struct {
TokenID string `json:"token_id"` // Added for Admin UI support
ProjectID string `json:"project_id"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
IsActive bool `json:"is_active"`
RequestCount int `json:"request_count"`
MaxRequests *int `json:"max_requests,omitempty"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
CacheHitCount int `json:"cache_hit_count"`
}
// TokenCreateResponse represents the response when creating a token
type TokenCreateResponse struct {
Token string `json:"token"`
ExpiresAt time.Time `json:"expires_at"`
}
// Pagination represents pagination metadata
type Pagination struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalItems int `json:"total_items"`
TotalPages int `json:"total_pages"`
HasNext bool `json:"has_next"`
HasPrev bool `json:"has_prev"`
}
// DashboardData represents dashboard statistics
type DashboardData struct {
TotalProjects int `json:"total_projects"`
TotalTokens int `json:"total_tokens"`
ActiveTokens int `json:"active_tokens"`
ExpiredTokens int `json:"expired_tokens"`
TotalRequests int `json:"total_requests"`
RequestsToday int `json:"requests_today"`
RequestsThisWeek int `json:"requests_this_week"`
}
// AuditEvent represents an audit log entry for the admin UI
type AuditEvent struct {
ID string `json:"id"`
Timestamp time.Time `json:"timestamp"`
Action string `json:"action"`
Actor string `json:"actor"`
ProjectID *string `json:"project_id,omitempty"`
RequestID *string `json:"request_id,omitempty"`
CorrelationID *string `json:"correlation_id,omitempty"`
ClientIP *string `json:"client_ip,omitempty"`
Method *string `json:"method,omitempty"`
Path *string `json:"path,omitempty"`
UserAgent *string `json:"user_agent,omitempty"`
Outcome string `json:"outcome"`
Reason *string `json:"reason,omitempty"`
TokenID *string `json:"token_id,omitempty"`
Metadata *string `json:"metadata,omitempty"` // JSON string
ParsedMeta *AuditMetadata `json:"-"` // Parsed metadata for template use
}
// AuditMetadata represents parsed metadata for easier template access
type AuditMetadata map[string]interface{}
// AuditEventsResponse represents the API response for audit events listing
type AuditEventsResponse struct {
Events []AuditEvent `json:"events"`
Pagination Pagination `json:"pagination"`
}
// GetDashboardData retrieves dashboard statistics
func (c *APIClient) GetDashboardData(ctx context.Context) (*DashboardData, error) {
// For now, calculate from projects and tokens lists
// In the future, this could be a dedicated dashboard endpoint
projects, _, err := c.GetProjects(ctx, 1, 1000) // fetch up to 1000 projects
if err != nil {
return nil, err
}
tokens, _, err := c.GetTokens(ctx, "", 1, 1000) // fetch up to 1000 tokens
if err != nil {
return nil, err
}
data := &DashboardData{
TotalProjects: len(projects),
TotalTokens: len(tokens),
}
// Count active vs. expired/revoked tokens and accumulate request counts.
// A nil ExpiresAt means the token never expires.
now := time.Now()
for _, token := range tokens {
if token.IsActive && (token.ExpiresAt == nil || token.ExpiresAt.After(now)) {
data.ActiveTokens++
} else {
data.ExpiredTokens++
}
data.TotalRequests += token.RequestCount
// Approximate today's requests: attribute the token's full request count
// if it was last used within the past 24 hours
if token.LastUsedAt != nil && token.LastUsedAt.After(now.AddDate(0, 0, -1)) {
data.RequestsToday += token.RequestCount
}
// Approximate this week's requests the same way over the past 7 days
if token.LastUsedAt != nil && token.LastUsedAt.After(now.AddDate(0, 0, -7)) {
data.RequestsThisWeek += token.RequestCount
}
}
return data, nil
}
// GetProjects retrieves a paginated list of projects
func (c *APIClient) GetProjects(ctx context.Context, page, pageSize int) ([]Project, *Pagination, error) {
// Since the Management API doesn't currently support pagination,
// we'll get all projects and simulate pagination
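// For example, 25 projects with pageSize 10 yield totalPages 3, and page 3
// returns items 20-24.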
req, err := c.newRequest(ctx, "GET", "/manage/projects", nil)
if err != nil {
return nil, nil, err
}
var projects []Project
if err := c.doRequest(req, &projects); err != nil {
return nil, nil, err
}
// Simulate pagination
totalItems := len(projects)
totalPages := (totalItems + pageSize - 1) / pageSize
start := (page - 1) * pageSize
end := start + pageSize
if start >= totalItems {
projects = []Project{}
} else {
if end > totalItems {
end = totalItems
}
projects = projects[start:end]
}
pagination := &Pagination{
Page: page,
PageSize: pageSize,
TotalItems: totalItems,
TotalPages: totalPages,
HasNext: page < totalPages,
HasPrev: page > 1,
}
return projects, pagination, nil
}
// GetProject retrieves a single project by ID
func (c *APIClient) GetProject(ctx context.Context, id string) (*Project, error) {
req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/manage/projects/%s", id), nil)
if err != nil {
return nil, err
}
var project Project
if err := c.doRequest(req, &project); err != nil {
return nil, err
}
return &project, nil
}
// CreateProject creates a new project
func (c *APIClient) CreateProject(ctx context.Context, name, openaiAPIKey string) (*Project, error) {
payload := map[string]string{
"name": name,
"openai_api_key": openaiAPIKey,
}
req, err := c.newRequest(ctx, "POST", "/manage/projects", payload)
if err != nil {
return nil, err
}
var project Project
if err := c.doRequest(req, &project); err != nil {
return nil, err
}
return &project, nil
}
// UpdateProject updates an existing project
func (c *APIClient) UpdateProject(ctx context.Context, id, name, openaiAPIKey string, isActive *bool) (*Project, error) {
payload := map[string]interface{}{}
if name != "" {
payload["name"] = name
}
if openaiAPIKey != "" {
payload["openai_api_key"] = openaiAPIKey
}
if isActive != nil {
payload["is_active"] = *isActive
}
req, err := c.newRequest(ctx, "PATCH", fmt.Sprintf("/manage/projects/%s", id), payload)
if err != nil {
return nil, err
}
var project Project
if err := c.doRequest(req, &project); err != nil {
return nil, err
}
return &project, nil
}
// DeleteProject deletes a project
func (c *APIClient) DeleteProject(ctx context.Context, id string) error {
req, err := c.newRequest(ctx, "DELETE", fmt.Sprintf("/manage/projects/%s", id), nil)
if err != nil {
return err
}
return c.doRequest(req, nil)
}
// GetTokens retrieves a paginated list of tokens
func (c *APIClient) GetTokens(ctx context.Context, projectID string, page, pageSize int) ([]Token, *Pagination, error) {
path := "/manage/tokens"
if projectID != "" {
path += "?projectId=" + url.QueryEscape(projectID)
}
req, err := c.newRequest(ctx, "GET", path, nil)
if err != nil {
return nil, nil, err
}
var tokens []Token
if err := c.doRequest(req, &tokens); err != nil {
return nil, nil, err
}
// Simulate pagination (similar to projects)
totalItems := len(tokens)
totalPages := (totalItems + pageSize - 1) / pageSize
start := (page - 1) * pageSize
end := start + pageSize
if start >= totalItems {
tokens = []Token{}
} else {
if end > totalItems {
end = totalItems
}
tokens = tokens[start:end]
}
pagination := &Pagination{
Page: page,
PageSize: pageSize,
TotalItems: totalItems,
TotalPages: totalPages,
HasNext: page < totalPages,
HasPrev: page > 1,
}
return tokens, pagination, nil
}
// CreateToken creates a new token for a project with a given duration in minutes
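// (e.g. durationMinutes=1440 requests a token valid for one day).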
func (c *APIClient) CreateToken(ctx context.Context, projectID string, durationMinutes int) (*TokenCreateResponse, error) {
payload := map[string]interface{}{
"project_id": projectID,
"duration_minutes": durationMinutes,
}
// Use newRequest and doRequest for consistent error handling
req, err := c.newRequest(ctx, "POST", "/manage/tokens", payload)
if err != nil {
return nil, err
}
var result TokenCreateResponse
if err := c.doRequest(req, &result); err != nil {
return nil, err
}
return &result, nil
}
// newRequest creates a new HTTP request with authentication
func (c *APIClient) newRequest(ctx context.Context, method, path string, body any) (*http.Request, error) {
var reqBody []byte
var err error
if body != nil {
reqBody, err = json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
}
endpoint := c.baseURL + path
req, err := http.NewRequestWithContext(ctx, method, endpoint, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Set headers
req.Header.Set("Authorization", "Bearer "+c.token)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
// Forward browser context when present
if v := ctx.Value(ctxKeyForwardedUA); v != nil {
if ua, ok := v.(string); ok && ua != "" {
req.Header.Set("X-Forwarded-User-Agent", ua)
req.Header.Set("X-Admin-Origin", "1")
}
}
if v := ctx.Value(ctxKeyForwardedReferer); v != nil {
if ref, ok := v.(string); ok && ref != "" {
req.Header.Set("X-Forwarded-Referer", ref)
req.Header.Set("X-Admin-Origin", "1")
}
}
if v := ctx.Value(ctxKeyForwardedIP); v != nil {
if ip, ok := v.(string); ok && ip != "" {
// Provide original browser IP for backend audit logging
req.Header.Set("X-Forwarded-For", ip)
req.Header.Set("X-Admin-Origin", "1")
}
}
return req, nil
}
// doRequest executes an HTTP request and handles the response
func (c *APIClient) doRequest(req *http.Request, result any) error {
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("request failed: %w", err)
}
defer func() {
err := resp.Body.Close()
if err != nil {
// Log or handle the error as appropriate
// For now, just log to standard error
fmt.Fprintf(os.Stderr, "failed to close response body: %v\n", err)
}
}()
if resp.StatusCode >= 400 {
var errorResp map[string]any
if err := json.NewDecoder(resp.Body).Decode(&errorResp); err == nil {
if msg, ok := errorResp["error"].(string); ok {
return fmt.Errorf("API error (%d): %s", resp.StatusCode, msg)
}
}
return fmt.Errorf("API error: %d %s", resp.StatusCode, resp.Status)
}
if result != nil && resp.StatusCode != http.StatusNoContent {
if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
return fmt.Errorf("failed to decode response: %w", err)
}
}
return nil
}
// GetAuditEvents retrieves a paginated list of audit events with optional filters
func (c *APIClient) GetAuditEvents(ctx context.Context, filters map[string]string, page, pageSize int) ([]AuditEvent, *Pagination, error) {
// Build query parameters
params := url.Values{}
for key, value := range filters {
if value != "" {
params.Set(key, value)
}
}
params.Set("page", fmt.Sprintf("%d", page))
params.Set("page_size", fmt.Sprintf("%d", pageSize))
endpoint := "/manage/audit"
if len(params) > 0 {
endpoint += "?" + params.Encode()
}
req, err := c.newRequest(ctx, "GET", endpoint, nil)
if err != nil {
return nil, nil, err
}
var response AuditEventsResponse
if err := c.doRequest(req, &response); err != nil {
return nil, nil, err
}
// Parse metadata for each event
for i := range response.Events {
if response.Events[i].Metadata != nil && *response.Events[i].Metadata != "" {
var meta AuditMetadata
if err := json.Unmarshal([]byte(*response.Events[i].Metadata), &meta); err == nil {
response.Events[i].ParsedMeta = &meta
}
}
}
return response.Events, &response.Pagination, nil
}
// GetAuditEvent retrieves a specific audit event by ID
func (c *APIClient) GetAuditEvent(ctx context.Context, id string) (*AuditEvent, error) {
req, err := c.newRequest(ctx, "GET", "/manage/audit/"+id, nil)
if err != nil {
return nil, err
}
var event AuditEvent
if err := c.doRequest(req, &event); err != nil {
return nil, err
}
// Parse metadata
if event.Metadata != nil && *event.Metadata != "" {
var meta AuditMetadata
if err := json.Unmarshal([]byte(*event.Metadata), &meta); err == nil {
event.ParsedMeta = &meta
}
}
return &event, nil
}
// GetToken retrieves a single token by ID
func (c *APIClient) GetToken(ctx context.Context, tokenID string) (*Token, error) {
req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/manage/tokens/%s", tokenID), nil)
if err != nil {
return nil, err
}
var token Token
if err := c.doRequest(req, &token); err != nil {
return nil, err
}
return &token, nil
}
// UpdateToken updates an existing token
func (c *APIClient) UpdateToken(ctx context.Context, tokenID string, isActive *bool, maxRequests *int) (*Token, error) {
payload := map[string]interface{}{}
if isActive != nil {
payload["is_active"] = *isActive
}
if maxRequests != nil {
payload["max_requests"] = *maxRequests
}
req, err := c.newRequest(ctx, "PATCH", fmt.Sprintf("/manage/tokens/%s", tokenID), payload)
if err != nil {
return nil, err
}
var token Token
if err := c.doRequest(req, &token); err != nil {
return nil, err
}
return &token, nil
}
// RevokeToken revokes a single token by setting is_active to false
func (c *APIClient) RevokeToken(ctx context.Context, tokenID string) error {
req, err := c.newRequest(ctx, "DELETE", fmt.Sprintf("/manage/tokens/%s", tokenID), nil)
if err != nil {
return err
}
return c.doRequest(req, nil)
}
// RevokeProjectTokens revokes all tokens for a project in bulk
func (c *APIClient) RevokeProjectTokens(ctx context.Context, projectID string) error {
req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/manage/projects/%s/tokens/revoke", projectID), nil)
if err != nil {
return err
}
return c.doRequest(req, nil)
}
// Package admin provides the HTTP server for the Admin UI.
// This package implements a separate web interface for managing
// projects and tokens via the Management API.
package admin
import (
"context"
"encoding/json"
"fmt"
"html/template"
"log"
"net/http"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"github.com/sofatutor/llm-proxy/internal/audit"
"github.com/sofatutor/llm-proxy/internal/config"
"github.com/sofatutor/llm-proxy/internal/logging"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
"go.uber.org/zap"
)
// Session represents a user session
type Session struct {
ID string
Token string
CreatedAt time.Time
ExpiresAt time.Time
RememberMe bool
}
// getSessionSecret derives a secret key from the management token and a salt
func getSessionSecret(cfg *config.Config) []byte {
salt := "llmproxy-cookie-salt" // TODO: move to config/env
return []byte(cfg.AdminUI.ManagementToken + salt)
}
//go:generate mockgen -destination=mock_api_client.go -package=admin . APIClientInterface

// APIClientInterface abstracts the API client for testability.
// Only the methods needed by the handlers are included.
type APIClientInterface interface {
GetDashboardData(ctx context.Context) (*DashboardData, error)
GetProjects(ctx context.Context, page, pageSize int) ([]Project, *Pagination, error)
GetTokens(ctx context.Context, projectID string, page, pageSize int) ([]Token, *Pagination, error)
CreateToken(ctx context.Context, projectID string, durationMinutes int) (*TokenCreateResponse, error)
GetProject(ctx context.Context, projectID string) (*Project, error)
UpdateProject(ctx context.Context, projectID string, name string, openAIAPIKey string, isActive *bool) (*Project, error)
DeleteProject(ctx context.Context, projectID string) error
CreateProject(ctx context.Context, name string, openAIAPIKey string) (*Project, error)
GetAuditEvents(ctx context.Context, filters map[string]string, page, pageSize int) ([]AuditEvent, *Pagination, error)
GetAuditEvent(ctx context.Context, id string) (*AuditEvent, error)
GetToken(ctx context.Context, tokenID string) (*Token, error)
UpdateToken(ctx context.Context, tokenID string, isActive *bool, maxRequests *int) (*Token, error)
RevokeToken(ctx context.Context, tokenID string) error
RevokeProjectTokens(ctx context.Context, projectID string) error
}
// Server represents the Admin UI HTTP server.
// It provides a web interface for managing projects and tokens
// by communicating with the Management API.
type Server struct {
server *http.Server
config *config.Config
engine *gin.Engine
apiClient *APIClient
logger *zap.Logger
// For testability: allow injection of token validation logic
ValidateTokenWithAPI func(context.Context, string) bool
// Audit logger for admin actions
auditLogger *audit.Logger
}
// NewServer creates a new Admin UI server with the provided configuration.
// It initializes the Gin engine, sets up routes, and configures the HTTP server.
func NewServer(cfg *config.Config) (*Server, error) {
// Initialize logger
logger, err := logging.NewLogger(cfg.LogLevel, cfg.LogFormat, cfg.LogFile)
if err != nil {
return nil, fmt.Errorf("failed to initialize admin server logger: %w", err)
}
// Set Gin mode based on log level
if cfg.LogLevel == "debug" {
gin.SetMode(gin.DebugMode)
} else {
gin.SetMode(gin.ReleaseMode)
}
engine := gin.New()
// Add middleware
engine.Use(gin.Logger())
engine.Use(gin.Recovery())
// Method override middleware for HTML forms; routing has already been resolved
// by the time this runs, so dedicated POST fallback routes are still required
engine.Use(func(c *gin.Context) {
if c.Request.Method == "POST" && c.Request.FormValue("_method") != "" {
c.Request.Method = c.Request.FormValue("_method")
}
c.Next()
})
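// The override pairs with a hidden form field; illustrative markup (assumed,
// not taken from the actual templates):
//
//	<form method="POST" action="/projects/{{ .ID }}">
//	  <input type="hidden" name="_method" value="DELETE">
//	</form>
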
// Add session middleware
store := cookie.NewStore(getSessionSecret(cfg))
engine.Use(sessions.Sessions("llmproxy_session", store))
// Create API client for communicating with Management API
// Note: API client will be updated when user logs in
var apiClient *APIClient
if cfg.AdminUI.ManagementToken != "" {
apiClient = NewAPIClient(cfg.AdminUI.APIBaseURL, cfg.AdminUI.ManagementToken)
}
// Initialize audit logger
var auditLogger *audit.Logger
if cfg.AuditEnabled && cfg.AuditLogFile != "" {
auditConfig := audit.LoggerConfig{
FilePath: cfg.AuditLogFile,
CreateDir: cfg.AuditCreateDir,
}
var err error
auditLogger, err = audit.NewLogger(auditConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize admin audit logger: %w", err)
}
} else {
auditLogger = audit.NewNullLogger()
}
s := &Server{
config: cfg,
engine: engine,
apiClient: apiClient,
logger: logger,
auditLogger: auditLogger,
server: &http.Server{
Addr: cfg.AdminUI.ListenAddr,
Handler: engine,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 60 * time.Second,
},
}
// Setup routes
s.setupRoutes()
return s, nil
}
// Start starts the Admin UI server.
// This method blocks until the server is shut down or an error occurs.
func (s *Server) Start() error {
return s.server.ListenAndServe()
}
// Shutdown gracefully shuts down the server without interrupting active connections.
func (s *Server) Shutdown(ctx context.Context) error {
if s.server == nil {
return nil
}
return s.server.Shutdown(ctx)
}
// setupRoutes configures all the routes for the Admin UI
func (s *Server) setupRoutes() {
// Serve static files (CSS, JS, images)
s.engine.Static("/static", "./web/static")
// Load HTML templates with custom functions using glob patterns
s.engine.SetFuncMap(s.templateFuncs())
td := s.config.AdminUI.TemplateDir
// Load all templates - both root level and subdirectories
templGlob := template.Must(template.New("").Funcs(s.templateFuncs()).ParseGlob(filepath.Join(td, "*.html")))
templGlob = template.Must(templGlob.ParseGlob(filepath.Join(td, "*/*.html")))
s.engine.SetHTMLTemplate(templGlob)
// Authentication routes (no middleware)
auth := s.engine.Group("/auth")
{
auth.GET("/login", s.handleLoginForm)
auth.POST("/login", s.handleLogin)
auth.GET("/logout", s.handleLogout) // Allow GET for direct URL access
auth.POST("/logout", s.handleLogout) // Keep POST for form submission
}
// Root route - redirect to dashboard
s.engine.GET("/", func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, "/dashboard")
})
// Protected routes with authentication middleware
protected := s.engine.Group("/")
protected.Use(s.authMiddleware())
{
// Dashboard
protected.GET("/dashboard", s.handleDashboard)
// Projects routes
projects := protected.Group("/projects")
{
projects.GET("", s.handleProjectsList)
projects.GET("/new", s.handleProjectsNew)
projects.POST("", s.handleProjectsCreate)
projects.GET("/:id", s.handleProjectsShow)
projects.GET("/:id/edit", s.handleProjectsEdit)
// HTML forms submit via POST with a _method override; Gin resolves the route
// before the override middleware runs, so provide POST fallback routes that
// dispatch to the correct handlers.
projects.POST("/:id", s.handleProjectsPostOverride)
projects.PUT("/:id", s.handleProjectsUpdate)
projects.DELETE("/:id", s.handleProjectsDelete)
projects.POST("/:id/revoke-tokens", s.handleProjectsBulkRevoke)
}
// Tokens routes
tokens := protected.Group("/tokens")
{
tokens.GET("", s.handleTokensList)
tokens.GET("/new", s.handleTokensNew)
tokens.POST("", s.handleTokensCreate)
tokens.GET("/:token", s.handleTokensShow)
tokens.GET("/:token/edit", s.handleTokensEdit)
// POST fallback for HTML forms with _method override
tokens.POST("/:token", s.handleTokensPostOverride)
tokens.PUT("/:token", s.handleTokensUpdate)
tokens.DELETE("/:token", s.handleTokensRevoke)
}
// Audit routes
audit := protected.Group("/audit")
{
audit.GET("", s.handleAuditList)
audit.GET("/:id", s.handleAuditShow)
}
}
// Health check
s.engine.GET("/health", func(c *gin.Context) {
adminHealth := gin.H{
"status": "ok",
"timestamp": time.Now(),
"service": "admin-ui",
"version": "0.1.0",
}
backendURL := s.config.AdminUI.APIBaseURL
if backendURL == "" {
backendURL = "http://localhost:8080"
}
if !strings.HasSuffix(backendURL, "/health") {
backendURL = strings.TrimRight(backendURL, "/") + "/health"
}
client := &http.Client{Timeout: 2 * time.Second}
backendHealth := gin.H{
"status": "down",
"error": "Backend unavailable",
}
if resp, err := client.Get(backendURL); err == nil {
if resp.StatusCode == http.StatusOK {
var backendData map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&backendData); err == nil {
backendHealth = backendData
}
}
// Close the body on every status, not just 200, to avoid leaking connections
if err := resp.Body.Close(); err != nil {
s.logger.Warn("failed to close backend health response body", zap.Error(err))
}
}
c.JSON(http.StatusOK, gin.H{
"admin": adminHealth,
"backend": backendHealth,
})
})
}
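// forwardedContext sketches a helper that could consolidate the browser-metadata
// forwarding that the handlers below currently repeat inline; the name is
// hypothetical and nothing references it yet.
func forwardedContext(c *gin.Context) context.Context {
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
// Use only the first hop of a comma-separated X-Forwarded-For chain
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
return ctx
}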
// Dashboard handlers
func (s *Server) handleDashboard(c *gin.Context) {
// Get API client from context
apiClientIface := c.MustGet("apiClient").(APIClientInterface)
// Get dashboard data from Management API with forwarded browser metadata
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
dashboardData, err := apiClientIface.GetDashboardData(ctx)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": fmt.Sprintf("Failed to load dashboard data: %v", err),
})
return
}
c.HTML(http.StatusOK, "dashboard.html", gin.H{
"title": "Dashboard",
"active": "dashboard",
"data": dashboardData,
})
}
// Project handlers
func (s *Server) handleProjectsList(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
page := getPageFromQuery(c, 1)
pageSize := getPageSizeFromQuery(c, 10)
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
// Forward best-effort original client IP from headers
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
projects, pagination, err := apiClient.GetProjects(ctx, page, pageSize)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": fmt.Sprintf("Failed to load projects: %v", err),
})
return
}
c.HTML(http.StatusOK, "projects/list.html", gin.H{
"title": "Projects",
"active": "projects",
"projects": projects,
"pagination": pagination,
})
}
func (s *Server) handleProjectsNew(c *gin.Context) {
c.HTML(http.StatusOK, "projects/new.html", gin.H{
"title": "Create Project",
"active": "projects",
})
}
func (s *Server) handleProjectsCreate(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
var req struct {
Name string `form:"name" binding:"required"`
OpenAIAPIKey string `form:"openai_api_key" binding:"required"`
}
if err := c.ShouldBind(&req); err != nil {
c.HTML(http.StatusBadRequest, "projects/new.html", gin.H{
"title": "Create Project",
"active": "projects",
"error": "Please fill in all required fields",
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
project, err := apiClient.CreateProject(ctx, req.Name, req.OpenAIAPIKey)
if err != nil {
c.HTML(http.StatusInternalServerError, "projects/new.html", gin.H{
"title": "Create Project",
"active": "projects",
"error": fmt.Sprintf("Failed to create project: %v", err),
})
return
}
c.Redirect(http.StatusSeeOther, fmt.Sprintf("/projects/%s", project.ID))
}
func (s *Server) handleProjectsShow(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
id := c.Param("id")
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
project, err := apiClient.GetProject(ctx, id)
if err != nil {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Project not found",
})
return
}
c.HTML(http.StatusOK, "projects/show.html", gin.H{
"title": "Project Details",
"active": "projects",
"project": project,
})
}
func (s *Server) handleProjectsEdit(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
id := c.Param("id")
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
project, err := apiClient.GetProject(ctx, id)
if err != nil {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Project not found",
})
return
}
c.HTML(http.StatusOK, "projects/edit.html", gin.H{
"title": "Edit Project",
"active": "projects",
"project": project,
})
}
func (s *Server) handleProjectsUpdate(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
id := c.Param("id")
var req struct {
Name string `form:"name" binding:"required"`
OpenAIAPIKey string `form:"openai_api_key" binding:"required"`
IsActive *bool `form:"is_active"`
}
if err := c.ShouldBind(&req); err != nil {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Invalid form data",
})
return
}
// Checkbox handling via helper for clarity and reuse
isActive := parseBoolFormField(c, "is_active")
isActivePtr := &isActive
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
project, err := apiClient.UpdateProject(ctx, id, req.Name, req.OpenAIAPIKey, isActivePtr)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": fmt.Sprintf("Failed to update project: %v", err),
})
return
}
c.Redirect(http.StatusSeeOther, fmt.Sprintf("/projects/%s", project.ID))
}
// handleProjectsPostOverride routes POST requests with _method overrides to the appropriate handler.
// It ensures form submissions to /projects/:id work even though Gin resolves routes before middleware.
func (s *Server) handleProjectsPostOverride(c *gin.Context) {
// Parse form to access _method
if err := c.Request.ParseForm(); err != nil {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Failed to parse form data",
})
return
}
method := c.PostForm("_method")
switch strings.ToUpper(method) {
case http.MethodPut:
s.handleProjectsUpdate(c)
return
case http.MethodDelete:
s.handleProjectsDelete(c)
return
default:
// No override provided; treat as bad request
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Unsupported method override for project action",
})
return
}
}
// parseBoolFormField returns the boolean value of a form field that may submit
// both a hidden "false" and a checkbox "true". The last value wins.
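// Typical markup (assumed): <input type="hidden" name="is_active" value="false">
// followed by <input type="checkbox" name="is_active" value="true">.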
func parseBoolFormField(c *gin.Context, name string) bool {
vals := c.PostFormArray(name)
if len(vals) == 0 {
return false
}
return strings.EqualFold(vals[len(vals)-1], "true")
}
func (s *Server) handleProjectsDelete(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
id := c.Param("id")
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
err := apiClient.DeleteProject(ctx, id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to delete project: %v", err),
})
return
}
c.Redirect(http.StatusSeeOther, "/projects")
}
// Token handlers
func (s *Server) handleTokensList(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
page := getPageFromQuery(c, 1)
pageSize := getPageSizeFromQuery(c, 10)
projectID := c.Query("project_id")
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
tokens, pagination, err := apiClient.GetTokens(ctx, projectID, page, pageSize)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": fmt.Sprintf("Failed to load tokens: %v", err),
})
return
}
// Fetch all projects to create a lookup map for project names
projects, _, err := apiClient.GetProjects(ctx, 1, 1000) // Get up to 1000 projects
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": fmt.Sprintf("Failed to load projects: %v", err),
})
return
}
// Create project ID to name lookup map
projectNames := make(map[string]string)
for _, project := range projects {
projectNames[project.ID] = project.Name
}
c.HTML(http.StatusOK, "tokens/list.html", gin.H{
"title": "Tokens",
"active": "tokens",
"tokens": tokens,
"pagination": pagination,
"projectId": projectID,
"projectNames": projectNames,
"now": time.Now(),
"currentTime": time.Now(),
})
}
func (s *Server) handleTokensNew(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
projects, _, err := apiClient.GetProjects(ctx, 1, 100)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": fmt.Sprintf("Failed to load projects: %v", err),
})
return
}
c.HTML(http.StatusOK, "tokens/new.html", gin.H{
"title": "Generate Token",
"active": "tokens",
"projects": projects,
})
}
func (s *Server) handleTokensCreate(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
var req struct {
ProjectID string `form:"project_id" binding:"required"`
DurationMinutes int `form:"duration_minutes" binding:"required,min=1,max=525600"`
}
if err := c.ShouldBind(&req); err != nil {
// forward context as well for consistency in audit logs
projCtx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
projCtx = context.WithValue(projCtx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
projCtx = context.WithValue(projCtx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
projCtx = context.WithValue(projCtx, ctxKeyForwardedReferer, ref)
}
projects, _, _ := apiClient.GetProjects(projCtx, 1, 100)
c.HTML(http.StatusBadRequest, "tokens/new.html", gin.H{
"title": "Generate Token",
"active": "tokens",
"projects": projects,
"error": "Please fill in all required fields correctly",
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
token, err := apiClient.CreateToken(ctx, req.ProjectID, req.DurationMinutes)
if err != nil {
// forward context as well for consistency in audit logs
projCtx := ctx
projects, _, _ := apiClient.GetProjects(projCtx, 1, 100)
c.HTML(http.StatusInternalServerError, "tokens/new.html", gin.H{
"title": "Generate Token",
"active": "tokens",
"projects": projects,
"error": fmt.Sprintf("Failed to create token: %v", err),
})
return
}
c.HTML(http.StatusOK, "tokens/created.html", gin.H{
"title": "Token Created",
"active": "tokens",
"token": token,
})
}
func (s *Server) handleTokensShow(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
tokenID := c.Param("token")
if tokenID == "" {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Token ID is required",
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
token, err := apiClient.GetToken(ctx, tokenID)
if err != nil {
if strings.Contains(err.Error(), "404") {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Token not found",
"details": fmt.Sprintf("Token %s was not found", tokenID),
})
} else {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to load token",
"details": err.Error(),
})
}
return
}
// Get project for the token
project, err := apiClient.GetProject(ctx, token.ProjectID)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to load project",
"details": err.Error(),
})
return
}
c.HTML(http.StatusOK, "tokens/show.html", gin.H{
"title": "Token Details",
"active": "tokens",
"token": token,
"project": project,
"tokenID": tokenID,
"now": time.Now(),
"currentTime": time.Now(),
})
}
// Helper functions
func getPageFromQuery(c *gin.Context, defaultValue int) int {
if page := c.Query("page"); page != "" {
if p, err := parsePositiveInt(page); err == nil && p > 0 {
return p
}
}
return defaultValue
}
func getPageSizeFromQuery(c *gin.Context, defaultValue int) int {
if size := c.Query("size"); size != "" {
if s, err := parsePositiveInt(size); err == nil && s > 0 && s <= 100 {
return s
}
}
return defaultValue
}
// parsePositiveInt parses s as a base-10 integer; callers validate positivity.
// strconv.Atoi rejects trailing garbage that fmt.Sscanf("%d") would silently accept.
func parsePositiveInt(s string) (int, error) {
return strconv.Atoi(s)
}
// templateFuncs returns custom template functions for HTML templates
func (s *Server) templateFuncs() template.FuncMap {
return template.FuncMap{
"stringOr": func(value any, fallback string) string {
// Safely dereference optional strings for templates
switch v := value.(type) {
case *string:
if v != nil && *v != "" {
return *v
}
case string:
if v != "" {
return v
}
}
return fallback
},
"add": func(a, b int) int {
return a + b
},
"sub": func(a, b int) int {
return a - b
},
"safeSub": func(a, b int) int {
// Returns max(0, a-b) to prevent negative values in templates
if a < b {
return 0
}
return a - b
},
"inc": func(a int) int {
return a + 1
},
"dec": func(a int) int {
return a - 1
},
"seq": func(start, end int) []int {
if start > end {
return []int{}
}
seq := make([]int, end-start+1)
for i := range seq {
seq[i] = start + i
}
return seq
},
"now": func() time.Time {
return time.Now()
},
"eq": func(a, b any) bool {
return a == b
},
"ne": func(a, b any) bool {
return a != b
},
"lt": func(a, b any) bool {
switch v := a.(type) {
case int:
if b2, ok := b.(int); ok {
return v < b2
}
case int64:
if b2, ok := b.(int64); ok {
return v < b2
}
case time.Time:
if b2, ok := b.(time.Time); ok {
return v.Before(b2)
}
}
return false
},
"gt": func(a, b any) bool {
switch v := a.(type) {
case int:
if b2, ok := b.(int); ok {
return v > b2
}
case int64:
if b2, ok := b.(int64); ok {
return v > b2
}
case time.Time:
if b2, ok := b.(time.Time); ok {
return v.After(b2)
}
}
return false
},
"le": func(a, b any) bool {
switch v := a.(type) {
case int:
if b2, ok := b.(int); ok {
return v <= b2
}
case int64:
if b2, ok := b.(int64); ok {
return v <= b2
}
case time.Time:
if b2, ok := b.(time.Time); ok {
return v.Before(b2) || v.Equal(b2)
}
}
return false
},
"ge": func(a, b any) bool {
switch v := a.(type) {
case int:
if b2, ok := b.(int); ok {
return v >= b2
}
case int64:
if b2, ok := b.(int64); ok {
return v >= b2
}
case time.Time:
if b2, ok := b.(time.Time); ok {
return v.After(b2) || v.Equal(b2)
}
}
return false
},
"and": func(a, b bool) bool {
return a && b
},
"or": func(a, b bool) bool {
return a || b
},
"not": func(a bool) bool {
return !a
},
"obfuscateAPIKey": func(apiKey string) string { return obfuscate.ObfuscateTokenGeneric(apiKey) },
"obfuscateToken": func(token string) string { return obfuscate.ObfuscateTokenSimple(token) },
"contains": func(s, substr string) bool {
return strings.Contains(s, substr)
},
"pageRange": func(current, total int) []int {
// Show up to 7 page numbers around current page
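// e.g. current=5, total=20 -> pages 2 through 8; current=1 -> pages 1 through 7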
start := current - 3
end := current + 3
if start < 1 {
start = 1
}
if end > total {
end = total
}
// Adjust if we have fewer than 7 pages to show
if end-start < 6 && total > 6 {
if start == 1 {
end = start + 6
if end > total {
end = total
}
} else if end == total {
start = end - 6
if start < 1 {
start = 1
}
}
}
pages := make([]int, 0, end-start+1)
for i := start; i <= end; i++ {
pages = append(pages, i)
}
return pages
},
"dict": func(values ...interface{}) map[string]interface{} {
if len(values)%2 != 0 {
return nil
}
dict := make(map[string]interface{}, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, ok := values[i].(string)
if !ok {
return nil
}
dict[key] = values[i+1]
}
return dict
},
}
}
// Authentication handlers
func (s *Server) handleLoginForm(c *gin.Context) {
c.HTML(http.StatusOK, "login.html", gin.H{
"title": "Sign In",
})
}
func (s *Server) handleLogin(c *gin.Context) {
logger := s.logger
if logger == nil {
logger = zap.NewNop()
}
var req struct {
ManagementToken string `form:"management_token" binding:"required"`
RememberMe bool `form:"remember_me"`
}
if err := c.Request.ParseForm(); err == nil {
// Only log field presence, not values
logger.Debug("raw POST form fields", zap.Strings("fields", getFormFieldNames(c.Request.PostForm)))
}
if err := c.ShouldBind(&req); err != nil {
logger.Warn("ShouldBind error", zap.Error(err))
c.HTML(http.StatusBadRequest, "login.html", gin.H{
"title": "Sign In",
"error": "Please enter your management token",
})
return
}
logger.Info("login attempt", zap.String("token", obfuscateToken(req.ManagementToken)), zap.Bool("remember_me", req.RememberMe))
// Use injected or default token validation
validate := s.ValidateTokenWithAPI
if validate == nil {
validate = s.validateTokenWithAPI
}
if !validate(c.Request.Context(), req.ManagementToken) {
logger.Error("token validation failed", zap.String("token", obfuscateToken(req.ManagementToken)))
c.HTML(http.StatusUnauthorized, "login.html", gin.H{
"title": "Sign In",
"error": "Invalid management token. Please check your token and try again.",
})
return
}
// Set session cookie
session := sessions.Default(c)
session.Set("token", req.ManagementToken)
if req.RememberMe {
session.Options(sessions.Options{
MaxAge: 30 * 24 * 60 * 60, // 30 days
Path: "/",
HttpOnly: true,
Secure: false,
})
} else {
session.Options(sessions.Options{
MaxAge: 0,
Path: "/",
HttpOnly: true,
Secure: false,
})
}
if err := session.Save(); err != nil {
logger.Error("session save error", zap.Error(err))
}
logger.Info("session saved", zap.String("token", obfuscateToken(req.ManagementToken)), zap.Bool("remember_me", req.RememberMe))
// Redirect to dashboard
c.Redirect(http.StatusSeeOther, "/dashboard")
}
func (s *Server) handleLogout(c *gin.Context) {
session := sessions.Default(c)
session.Clear()
if err := session.Save(); err != nil {
s.logger.Error("session save error", zap.Error(err))
}
c.Redirect(http.StatusSeeOther, "/auth/login")
}
// authMiddleware checks for valid session
func (s *Server) authMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
session := sessions.Default(c)
token := session.Get("token")
if token == nil {
c.Redirect(http.StatusSeeOther, "/auth/login")
c.Abort()
return
}
// The Management API re-validates the token on every request, so no extra
// validation (or caching) is done here; bind a client to the session token.
apiClient := NewAPIClient(s.config.AdminUI.APIBaseURL, token.(string))
c.Set("apiClient", apiClient)
c.Next()
}
}
// validateTokenWithAPI validates a token against the Management API
func (s *Server) validateTokenWithAPI(ctx context.Context, token string) bool {
// Create an HTTP client to validate the token
client := &http.Client{Timeout: 10 * time.Second}
// Try to access a simple management endpoint with the token
req, err := http.NewRequestWithContext(ctx, "GET", s.config.AdminUI.APIBaseURL+"/manage/projects", nil)
if err != nil {
return false
}
// Add authorization header
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
// Make the request
resp, err := client.Do(req)
if err != nil {
return false
}
defer func() {
if err := resp.Body.Close(); err != nil {
s.logger.Warn("Error closing response body", zap.Error(err))
}
}()
// Valid token should return 200 OK (or other success status)
return resp.StatusCode >= 200 && resp.StatusCode < 300
}
// getFormFieldNames returns a slice of form field names for logging
func getFormFieldNames(form map[string][]string) []string {
fields := make([]string, 0, len(form))
for k := range form {
fields = append(fields, k)
}
return fields
}
// obfuscateToken returns a partially masked version of a token for logging
func obfuscateToken(token string) string {
return obfuscate.ObfuscateTokenSimple(token)
}
// Audit handlers
func (s *Server) handleAuditList(c *gin.Context) {
// Get API client from context
apiClientIface := c.MustGet("apiClient").(APIClientInterface)
// Parse query parameters for filtering
filters := make(map[string]string)
query := c.Request.URL.Query()
// Filter parameters
if action := query.Get("action"); action != "" {
filters["action"] = action
}
if outcome := query.Get("outcome"); outcome != "" {
filters["outcome"] = outcome
}
if projectID := query.Get("project_id"); projectID != "" {
filters["project_id"] = projectID
}
if actor := query.Get("actor"); actor != "" {
filters["actor"] = actor
}
if clientIP := query.Get("client_ip"); clientIP != "" {
filters["client_ip"] = clientIP
}
if requestID := query.Get("request_id"); requestID != "" {
filters["request_id"] = requestID
}
if method := query.Get("method"); method != "" {
filters["method"] = method
}
if path := query.Get("path"); path != "" {
filters["path"] = path
}
if search := query.Get("search"); search != "" {
filters["search"] = search
}
if startTime := query.Get("start_time"); startTime != "" {
filters["start_time"] = startTime
}
if endTime := query.Get("end_time"); endTime != "" {
filters["end_time"] = endTime
}
// Parse pagination
page := 1
if pageStr := query.Get("page"); pageStr != "" {
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
page = p
}
}
pageSize := 20
if pageSizeStr := query.Get("page_size"); pageSizeStr != "" {
if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 && ps <= 100 {
pageSize = ps
}
}
// Get audit events with forwarded browser metadata
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
// Forward only the first hop, consistent with the other handlers
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
}
events, pagination, err := apiClientIface.GetAuditEvents(ctx, filters, page, pageSize)
if err != nil {
log.Printf("Failed to get audit events: %v", err)
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to load audit events",
"details": err.Error(),
})
return
}
c.HTML(http.StatusOK, "audit/list.html", gin.H{
"title": "Audit Events",
"active": "audit",
"events": events,
"pagination": pagination,
"filters": filters,
"query": query,
})
}
func (s *Server) handleAuditShow(c *gin.Context) {
// Get API client from context
apiClientIface := c.MustGet("apiClient").(APIClientInterface)
// Get audit event ID from URL
id := c.Param("id")
if id == "" {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Invalid audit event ID",
"details": "Audit event ID is required",
})
return
}
// Get audit event with forwarded browser metadata
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
// Forward only the first hop, consistent with the other handlers
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
}
event, err := apiClientIface.GetAuditEvent(ctx, id)
if err != nil {
log.Printf("Failed to get audit event %s: %v", id, err)
if strings.Contains(err.Error(), "not found") {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Audit event not found",
"details": fmt.Sprintf("Audit event with ID %s was not found", id),
})
} else {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to load audit event",
"details": err.Error(),
})
}
return
}
c.HTML(http.StatusOK, "audit/show.html", gin.H{
"title": "Audit Event",
"active": "audit",
"event": event,
})
}
// Token edit/revoke handlers
func (s *Server) handleTokensEdit(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
tokenID := c.Param("token")
if tokenID == "" {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Token ID is required",
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
token, err := apiClient.GetToken(ctx, tokenID)
if err != nil {
if strings.Contains(err.Error(), "404") {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Token not found",
"details": fmt.Sprintf("Token %s was not found", tokenID),
})
} else {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to load token",
"details": err.Error(),
})
}
return
}
// Get project for the token
project, err := apiClient.GetProject(ctx, token.ProjectID)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to load project",
"details": err.Error(),
})
return
}
c.HTML(http.StatusOK, "tokens/edit.html", gin.H{
"title": "Edit Token",
"active": "tokens",
"token": token,
"project": project,
"tokenID": tokenID,
})
}
func (s *Server) handleTokensUpdate(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
tokenID := c.Param("token")
if tokenID == "" {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Token ID is required",
})
return
}
var req struct {
IsActive *bool `form:"is_active"`
MaxRequests *int `form:"max_requests"`
}
if err := c.ShouldBind(&req); err != nil {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Invalid form data",
"details": err.Error(),
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
_, err := apiClient.UpdateToken(ctx, tokenID, req.IsActive, req.MaxRequests)
if err != nil {
if strings.Contains(err.Error(), "404") {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Token not found",
"details": fmt.Sprintf("Token %s was not found", tokenID),
})
} else {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to update token",
"details": err.Error(),
})
}
return
}
c.Redirect(http.StatusSeeOther, fmt.Sprintf("/tokens/%s", tokenID))
}
func (s *Server) handleTokensRevoke(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
tokenID := c.Param("token")
if tokenID == "" {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Token ID is required",
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
err := apiClient.RevokeToken(ctx, tokenID)
if err != nil {
if strings.Contains(err.Error(), "404") {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Token not found",
"details": fmt.Sprintf("Token %s was not found", tokenID),
})
} else {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to revoke token",
"details": err.Error(),
})
}
return
}
c.Redirect(http.StatusSeeOther, "/tokens")
}
// handleTokensPostOverride routes POST requests with _method overrides for token actions.
func (s *Server) handleTokensPostOverride(c *gin.Context) {
// Parse form to access _method
if err := c.Request.ParseForm(); err != nil {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Failed to parse form data",
})
return
}
method := c.PostForm("_method")
switch strings.ToUpper(method) {
case http.MethodPut:
s.handleTokensUpdate(c)
return
case http.MethodDelete:
s.handleTokensRevoke(c)
return
default:
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Unsupported method override for token action",
})
return
}
}
// Project bulk revoke handler
func (s *Server) handleProjectsBulkRevoke(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
projectID := c.Param("id")
if projectID == "" {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Project ID is required",
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
err := apiClient.RevokeProjectTokens(ctx, projectID)
if err != nil {
if strings.Contains(err.Error(), "404") {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Project not found",
"details": fmt.Sprintf("Project %s was not found", projectID),
})
} else {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to revoke project tokens",
"details": err.Error(),
})
}
return
}
c.Redirect(http.StatusSeeOther, fmt.Sprintf("/projects/%s", projectID))
}
// Package api provides shared utility functions for CLI and API clients.
package api
import (
"fmt"
"os"
"strings"
"time"
"github.com/spf13/cobra"
)
// ObfuscateKey obfuscates a sensitive key for display.
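// For example, the 15-character key "sk-abcdef123456" becomes "sk-a*******3456";
// keys of 8 characters or fewer collapse to "****".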
func ObfuscateKey(key string) string {
if len(key) <= 8 {
return "****"
}
return key[:4] + strings.Repeat("*", len(key)-8) + key[len(key)-4:]
}
// ParseTimeHeader parses a time header in RFC3339Nano format.
func ParseTimeHeader(val string) time.Time {
if val == "" {
return time.Time{}
}
t, err := time.Parse(time.RFC3339Nano, val)
if err != nil {
return time.Time{}
}
return t
}
// GetManagementToken gets the management token from a cobra command or environment.
func GetManagementToken(cmd *cobra.Command) (string, error) {
mgmtToken, _ := cmd.Flags().GetString("management-token")
if mgmtToken == "" {
mgmtToken = os.Getenv("MANAGEMENT_TOKEN")
}
if mgmtToken == "" {
return "", fmt.Errorf("management token is required (set MANAGEMENT_TOKEN env or use --management-token)")
}
return mgmtToken, nil
}
package audit
import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"sync"
)
// Logger handles writing audit events to multiple backends (file and database)
// with immutable semantics. It provides thread-safe operations and ensures
// all audit events are persisted.
type Logger struct {
file *os.File
writer io.Writer
mutex sync.Mutex
path string
dbStore DatabaseStore
dbEnabled bool
}
// DatabaseStore defines the interface for database audit storage
type DatabaseStore interface {
StoreAuditEvent(ctx context.Context, event *Event) error
}
// LoggerConfig holds configuration for the audit logger
type LoggerConfig struct {
// FilePath is the path to the audit log file
FilePath string
// CreateDir determines whether to create parent directories if they don't exist
CreateDir bool
// DatabaseStore is an optional database backend for audit events
DatabaseStore DatabaseStore
// EnableDatabase determines whether to store events in database
EnableDatabase bool
}
// NewLogger creates a new audit logger that writes to the specified file.
// If config.CreateDir is true, it will create parent directories if they don't exist.
// If a database store is provided, events will also be persisted to the database.
func NewLogger(config LoggerConfig) (*Logger, error) {
if config.FilePath == "" {
return nil, fmt.Errorf("audit log file path cannot be empty")
}
// Create parent directories if needed
if config.CreateDir {
dir := filepath.Dir(config.FilePath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create audit log directory: %w", err)
}
}
// Open file for appending with appropriate permissions
file, err := os.OpenFile(config.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return nil, fmt.Errorf("failed to open audit log file: %w", err)
}
return &Logger{
file: file,
writer: file,
path: config.FilePath,
dbStore: config.DatabaseStore,
dbEnabled: config.EnableDatabase && config.DatabaseStore != nil,
}, nil
}
// Log writes an audit event to both file and database backends.
// Events are written as JSON lines (JSONL format) for easy parsing.
// This method is thread-safe. A file write failure returns an error; a
// database write failure only emits a warning and does not fail the call.
func (l *Logger) Log(event *Event) error {
if event == nil {
return fmt.Errorf("audit event cannot be nil")
}
l.mutex.Lock()
defer l.mutex.Unlock()
// Write to file backend
if err := l.logToFile(event); err != nil {
return fmt.Errorf("failed to log to file: %w", err)
}
// Write to database backend if enabled
if l.dbEnabled {
// Use a background context for database operations; a timeout could be
// added here if stores are expected to block.
ctx := context.Background()
if err := l.dbStore.StoreAuditEvent(ctx, event); err != nil {
// Avoid hard dependency on a logger; keep stdout warning for now but
// ensure it's concise. In production, this would be wired to a structured logger.
_, _ = io.WriteString(os.Stdout, "audit: failed to store event in database\n")
}
}
return nil
}
// logToFile writes an audit event to the file backend
func (l *Logger) logToFile(event *Event) error {
// Encode event as JSON
data, err := json.Marshal(event)
if err != nil {
return fmt.Errorf("failed to marshal audit event: %w", err)
}
// Write JSON line followed by newline
if _, err := l.writer.Write(data); err != nil {
return fmt.Errorf("failed to write audit event: %w", err)
}
if _, err := l.writer.Write([]byte("\n")); err != nil {
return fmt.Errorf("failed to write newline: %w", err)
}
// Sync to ensure data is written to disk
if syncer, ok := l.writer.(interface{ Sync() error }); ok {
if err := syncer.Sync(); err != nil {
return fmt.Errorf("failed to sync audit log: %w", err)
}
}
return nil
}
// Close closes the audit log file.
// After calling Close, the logger should not be used for logging.
func (l *Logger) Close() error {
l.mutex.Lock()
defer l.mutex.Unlock()
if l.file != nil {
err := l.file.Close()
l.file = nil
return err
}
return nil
}
// GetPath returns the file path of the audit log
func (l *Logger) GetPath() string {
return l.path
}
// NewNullLogger creates a logger that discards all events.
// Useful for testing or when audit logging is disabled.
func NewNullLogger() *Logger {
return &Logger{
writer: io.Discard,
dbEnabled: false,
}
}
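// Usage sketch from a consuming package (paths are illustrative; the
// database backend is optional):
//
//	logger, err := audit.NewLogger(audit.LoggerConfig{
//	    FilePath:  "./data/audit.log",
//	    CreateDir: true,
//	})
//	if err != nil {
//	    return err
//	}
//	defer func() { _ = logger.Close() }()
//	_ = logger.Log(audit.NewEvent(audit.ActionTokenCreate, audit.ActorAdmin, audit.ResultSuccess))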
// Package audit provides audit logging functionality for security-sensitive events
// in the LLM proxy. It implements a separate audit sink with immutable semantics
// for compliance and security investigations.
package audit
import (
"time"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
)
// Event represents a security audit event with canonical fields.
// All audit events must include these core fields for compliance and investigation purposes.
type Event struct {
// Timestamp when the event occurred (ISO8601 format)
Timestamp time.Time `json:"timestamp"`
// Action describes what operation was performed (e.g., "token.create", "project.delete")
Action string `json:"action"`
// Actor identifies who performed the action (user ID, token ID, or system)
Actor string `json:"actor"`
// ProjectID identifies which project was affected (if applicable)
ProjectID string `json:"project_id,omitempty"`
// RequestID for correlation with request logs
RequestID string `json:"request_id,omitempty"`
// CorrelationID for tracing across services
CorrelationID string `json:"correlation_id,omitempty"`
// ClientIP is the IP address of the client making the request
ClientIP string `json:"client_ip,omitempty"`
// Result indicates success or failure of the operation
Result ResultType `json:"result"`
// Details contains additional context about the event (no secrets)
Details map[string]interface{} `json:"details,omitempty"`
}
// ResultType represents the outcome of an audited operation
type ResultType string
const (
// ResultSuccess indicates the operation completed successfully
ResultSuccess ResultType = "success"
// ResultFailure indicates the operation failed
ResultFailure ResultType = "failure"
// ResultDenied indicates the operation was denied (e.g., due to policy)
ResultDenied ResultType = "denied"
// ResultError indicates the operation encountered an error
ResultError ResultType = "error"
)
// Action constants for standardized audit event types
const (
// Token lifecycle actions
ActionTokenCreate = "token.create"
ActionTokenRead = "token.read"
ActionTokenUpdate = "token.update"
ActionTokenRevoke = "token.revoke"
ActionTokenRevokeBatch = "token.revoke_batch"
ActionTokenDelete = "token.delete"
ActionTokenList = "token.list"
ActionTokenValidate = "token.validate"
ActionTokenAccess = "token.access"
// Project lifecycle actions
ActionProjectCreate = "project.create"
ActionProjectRead = "project.read"
ActionProjectUpdate = "project.update"
ActionProjectDelete = "project.delete"
ActionProjectList = "project.list"
// Proxy request actions
ActionProxyRequest = "proxy_request"
// Admin actions
ActionAdminLogin = "admin.login"
ActionAdminLogout = "admin.logout"
ActionAdminAccess = "admin.access"
// Audit actions
ActionAuditList = "audit.list"
ActionAuditShow = "audit.show"
// Cache actions
ActionCachePurge = "cache.purge"
)
// Actor types for common audit actors
const (
ActorSystem = "system"
ActorAnonymous = "anonymous"
ActorAdmin = "admin"
ActorManagement = "management_api"
)
// NewEvent creates a new audit event with the specified action and result.
// The timestamp is automatically set to the current time.
func NewEvent(action string, actor string, result ResultType) *Event {
return &Event{
Timestamp: time.Now().UTC(),
Action: action,
Actor: actor,
Result: result,
Details: make(map[string]interface{}),
}
}
// WithProjectID sets the project ID for the audit event
func (e *Event) WithProjectID(projectID string) *Event {
e.ProjectID = projectID
return e
}
// WithRequestID sets the request ID for correlation with request logs
func (e *Event) WithRequestID(requestID string) *Event {
e.RequestID = requestID
return e
}
// WithCorrelationID sets the correlation ID for tracing across services
func (e *Event) WithCorrelationID(correlationID string) *Event {
e.CorrelationID = correlationID
return e
}
// WithClientIP sets the client IP address for the audit event
func (e *Event) WithClientIP(clientIP string) *Event {
e.ClientIP = clientIP
return e
}
// WithDetail adds a detail key-value pair to the audit event.
// Secrets and sensitive information should be obfuscated before calling this method.
func (e *Event) WithDetail(key string, value interface{}) *Event {
if e.Details == nil {
e.Details = make(map[string]interface{})
}
e.Details[key] = value
return e
}
// WithTokenID adds an obfuscated token ID to the audit event details
func (e *Event) WithTokenID(token string) *Event {
return e.WithDetail("token_id", obfuscate.ObfuscateTokenGeneric(token))
}
// WithError adds error information to the audit event details
func (e *Event) WithError(err error) *Event {
if err != nil {
return e.WithDetail("error", err.Error())
}
return e
}
// WithIPAddress adds the client IP address to the audit event details.
//
// Deprecated: Use WithClientIP instead for first-class IP field support.
func (e *Event) WithIPAddress(ip string) *Event {
return e.WithDetail("ip_address", ip)
}
// WithUserAgent adds the user agent to the audit event details
func (e *Event) WithUserAgent(userAgent string) *Event {
return e.WithDetail("user_agent", userAgent)
}
// WithHTTPMethod adds the HTTP method to the audit event details
func (e *Event) WithHTTPMethod(method string) *Event {
return e.WithDetail("http_method", method)
}
// WithEndpoint adds the API endpoint to the audit event details
func (e *Event) WithEndpoint(endpoint string) *Event {
return e.WithDetail("endpoint", endpoint)
}
// WithDuration adds the operation duration to the audit event details
func (e *Event) WithDuration(duration time.Duration) *Event {
return e.WithDetail("duration_ms", duration.Milliseconds())
}
// WithReason adds a reason/description to the audit event details
func (e *Event) WithReason(reason string) *Event {
return e.WithDetail("reason", reason)
}
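// Builder sketch: events are assembled fluently and then handed to
// Logger.Log (rawToken and elapsed are hypothetical values):
//
//	evt := audit.NewEvent(audit.ActionTokenRevoke, audit.ActorAdmin, audit.ResultSuccess).
//	    WithProjectID("proj-123").
//	    WithClientIP("203.0.113.7").
//	    WithTokenID(rawToken). // stored obfuscated, never in the clear
//	    WithDuration(elapsed)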
// Package client provides HTTP client functionality for communicating with the LLM Proxy API.
package client
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/chzyer/readline"
)
// ChatMessage represents a message in the chat
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// ChatRequest represents a request to the chat API
type ChatRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream"`
}
// ChatResponse represents a response from the chat API
type ChatResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Message ChatMessage `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// ChatCompletionStreamResponse represents a chunked response in a stream
type ChatCompletionStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Delta struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
} `json:"delta"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// ChatClient handles communication with the LLM Proxy chat API
type ChatClient struct {
BaseURL string
Token string
HTTPClient *http.Client
}
// ChatOptions configures chat request parameters
type ChatOptions struct {
Model string
Temperature float64
MaxTokens int
UseStreaming bool
VerboseMode bool
}
// NewChatClient creates a new chat client
func NewChatClient(baseURL, token string) *ChatClient {
return &ChatClient{
BaseURL: baseURL,
Token: token,
HTTPClient: &http.Client{
Timeout: 60 * time.Second,
},
}
}
// SendChatRequest sends a chat request and returns the response. If rl is
// non-nil, streamed output is written to its configured stdout; otherwise it
// is printed to standard output.
func (c *ChatClient) SendChatRequest(messages []ChatMessage, options ChatOptions, rl *readline.Instance) (*ChatResponse, error) {
if c.Token == "" {
return nil, fmt.Errorf("token is required")
}
// Validate proxy URL
if _, err := url.Parse(c.BaseURL); err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
request := ChatRequest{
Model: options.Model,
Messages: messages,
Temperature: options.Temperature,
MaxTokens: options.MaxTokens,
Stream: options.UseStreaming,
}
jsonData, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
if options.VerboseMode {
fmt.Printf("Request: %s\n", string(jsonData))
}
req, err := http.NewRequest("POST", c.BaseURL+"/v1/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.Token)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Fprintf(os.Stderr, "failed to close response body: %v\n", err)
}
}()
if resp.StatusCode >= 400 {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(body))
}
if options.UseStreaming {
return c.handleStreamingResponse(resp, rl, options.VerboseMode)
}
return c.handleNonStreamingResponse(resp, options.VerboseMode)
}
// handleStreamingResponse processes streaming (SSE) chat responses, echoing
// deltas as they arrive and accumulating them into a final ChatResponse.
func (c *ChatClient) handleStreamingResponse(resp *http.Response, rl *readline.Instance, verbose bool) (*ChatResponse, error) {
scanner := bufio.NewScanner(resp.Body)
var finalResponse *ChatResponse
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
break
}
var streamResp ChatCompletionStreamResponse
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
if verbose {
fmt.Printf("Failed to parse stream data: %v\n", err)
}
continue
}
if len(streamResp.Choices) > 0 {
choice := streamResp.Choices[0]
if choice.Delta.Content != "" {
if rl != nil && rl.Config.Stdout != nil {
if _, err := rl.Config.Stdout.Write([]byte(choice.Delta.Content)); err != nil {
return nil, fmt.Errorf("failed to write streaming content: %w", err)
}
} else {
fmt.Print(choice.Delta.Content)
}
}
// Convert to final response format
if finalResponse == nil {
finalResponse = &ChatResponse{
ID: streamResp.ID,
Object: streamResp.Object,
Created: streamResp.Created,
Model: streamResp.Model,
Choices: []struct {
Index int `json:"index"`
Message ChatMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}{
{
Index: choice.Index,
Message: ChatMessage{
Role: "assistant",
Content: "",
},
FinishReason: choice.FinishReason,
},
},
Usage: streamResp.Usage,
}
}
// Accumulate content
finalResponse.Choices[0].Message.Content += choice.Delta.Content
if choice.FinishReason != "" {
finalResponse.Choices[0].FinishReason = choice.FinishReason
}
}
}
if rl != nil && rl.Config.Stdout != nil {
if _, err := rl.Config.Stdout.Write([]byte("\n")); err != nil {
return nil, fmt.Errorf("failed to write newline after streaming: %w", err)
}
} else {
fmt.Println()
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("stream reading error: %w", err)
}
if finalResponse == nil {
return nil, fmt.Errorf("no response received from stream")
}
return finalResponse, nil
}
// handleNonStreamingResponse processes non-streaming chat responses
func (c *ChatClient) handleNonStreamingResponse(resp *http.Response, verbose bool) (*ChatResponse, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if verbose {
fmt.Printf("Response: %s\n", string(body))
}
var chatResp ChatResponse
if err := json.Unmarshal(body, &chatResp); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return &chatResp, nil
}
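// Usage sketch (base URL, token, and model are placeholders):
//
//	c := client.NewChatClient("http://localhost:8080", "tok-...")
//	resp, err := c.SendChatRequest(
//	    []client.ChatMessage{{Role: "user", Content: "Hello"}},
//	    client.ChatOptions{Model: "gpt-4o-mini", UseStreaming: false},
//	    nil, // no readline instance: streamed output would go to os.Stdout
//	)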
// Package config handles application configuration loading and validation
// from environment variables, providing a type-safe configuration structure.
package config
import (
"fmt"
"os"
"strconv"
"strings"
"time"
)
// Config holds all application configuration values loaded from environment variables.
// It provides a centralized, type-safe way to access configuration throughout the application.
type Config struct {
// Server configuration
ListenAddr string // Address to listen on (e.g., ":8080")
RequestTimeout time.Duration // Timeout for upstream API requests
MaxRequestSize int64 // Maximum size of incoming requests in bytes
MaxConcurrentReqs int // Maximum number of concurrent requests
// Environment
APIEnv string // API environment: 'production', 'development', 'test'
// Database configuration
DatabasePath string // Path to the SQLite database file
DatabasePoolSize int // Number of connections in the database pool
// Authentication
ManagementToken string // Token for admin operations, used to access the management API
// API Provider configuration
APIConfigPath string // Path to the API providers configuration file
DefaultAPIProvider string // Default API provider to use
OpenAIAPIURL string // Base URL for OpenAI API (legacy support)
EnableStreaming bool // Whether to enable streaming responses from APIs
// Admin UI settings
AdminUIPath string // Base path for the admin UI
AdminUI AdminUIConfig // Admin UI server configuration
// Logging
LogLevel string // Log level (debug, info, warn, error)
LogFormat string // Log format (json, text)
LogFile string // Path to log file (empty for stdout)
// Audit logging
AuditEnabled bool // Enable audit logging for security events
AuditLogFile string // Path to audit log file (empty to disable)
AuditCreateDir bool // Create parent directories for audit log file
AuditStoreInDB bool // Store audit events in database for analytics
// Observability middleware
ObservabilityEnabled bool // Enable async observability middleware
ObservabilityBufferSize int // Buffer size for in-memory event bus
// CORS settings
CORSAllowedOrigins []string // Allowed origins for CORS
CORSAllowedMethods []string // Allowed methods for CORS
CORSAllowedHeaders []string // Allowed headers for CORS
CORSMaxAge time.Duration // Max age for CORS preflight responses
// Rate limiting
GlobalRateLimit int // Maximum requests per minute globally
IPRateLimit int // Maximum requests per minute per IP
// Distributed rate limiting
DistributedRateLimitEnabled bool // Enable Redis-backed distributed rate limiting
DistributedRateLimitPrefix string // Redis key prefix for rate limit counters
DistributedRateLimitKeySecret string // HMAC secret for hashing token IDs in Redis keys (security)
DistributedRateLimitWindow time.Duration // Sliding window duration for rate limiting
DistributedRateLimitMax int // Maximum requests per window
DistributedRateLimitFallback bool // Enable fallback to in-memory when Redis unavailable
// Monitoring
EnableMetrics bool // Whether to enable a lightweight metrics endpoint (provider-agnostic)
MetricsPath string // Path for metrics endpoint
// Cleanup
TokenCleanupInterval time.Duration // Interval for cleaning up expired tokens
// Project active guard configuration
EnforceProjectActive bool // Whether to enforce project active status (default: true)
ActiveCacheTTL time.Duration // TTL for project active status cache (e.g., 5s)
ActiveCacheMax int // Maximum entries in project active status cache (e.g., 10000)
// Event bus configuration
EventBusBackend string // Backend for event bus: "redis", "redis-streams", or "in-memory"
RedisAddr string // Redis server address (e.g., "localhost:6379")
RedisDB int // Redis database number (default: 0)
// Redis Streams configuration (when EventBusBackend = "redis-streams")
RedisStreamKey string // Redis stream key name (default: "llm-proxy-events")
RedisConsumerGroup string // Consumer group name (default: "llm-proxy-dispatchers")
RedisConsumerName string // Consumer name within the group (should be unique per instance)
RedisStreamMaxLen int64 // Max stream length (0 = unlimited, default: 10000)
RedisStreamBlockTime time.Duration // Block timeout for reading (default: 5s)
RedisStreamClaimTime time.Duration // Min idle time before claiming pending msgs (default: 30s)
RedisStreamBatchSize int64 // Batch size for reading messages (default: 100)
// Cache stats aggregation
CacheStatsBufferSize int // Buffer size for async cache stats aggregation (default: 1000)
}
// AdminUIConfig holds configuration for the Admin UI server
type AdminUIConfig struct {
ListenAddr string // Address for admin UI server to listen on
APIBaseURL string // Base URL of the Management API
ManagementToken string // Token for accessing Management API
Enabled bool // Whether Admin UI is enabled
TemplateDir string // Directory for HTML templates (default: "web/templates")
}
// New creates a new configuration with values from environment variables.
// It applies default values where environment variables are not set,
// and validates required configuration settings.
//
// Returns a populated Config struct and nil error on success,
// or nil and an error if validation fails.
func New() (*Config, error) {
config := &Config{
// Server defaults
ListenAddr: getEnvString("LISTEN_ADDR", ":8080"),
RequestTimeout: getEnvDuration("REQUEST_TIMEOUT", 30*time.Second),
MaxRequestSize: getEnvInt64("MAX_REQUEST_SIZE", 10*1024*1024), // 10MB
MaxConcurrentReqs: getEnvInt("MAX_CONCURRENT_REQUESTS", 100),
// Environment
APIEnv: getEnvString("API_ENV", "development"),
// Database defaults
DatabasePath: getEnvString("DATABASE_PATH", "./data/llm-proxy.db"),
DatabasePoolSize: getEnvInt("DATABASE_POOL_SIZE", 10),
// Authentication
ManagementToken: getEnvString("MANAGEMENT_TOKEN", ""),
// API Provider settings
APIConfigPath: getEnvString("API_CONFIG_PATH", "./config/api_providers.yaml"),
DefaultAPIProvider: getEnvString("DEFAULT_API_PROVIDER", "openai"),
OpenAIAPIURL: getEnvString("OPENAI_API_URL", "https://api.openai.com"),
EnableStreaming: getEnvBool("ENABLE_STREAMING", true),
// Admin UI settings
AdminUIPath: getEnvString("ADMIN_UI_PATH", "/admin"),
AdminUI: AdminUIConfig{
ListenAddr: getEnvString("ADMIN_UI_LISTEN_ADDR", ":8081"),
APIBaseURL: getEnvString("ADMIN_UI_API_BASE_URL", "http://localhost:8080"),
ManagementToken: getEnvString("MANAGEMENT_TOKEN", ""),
Enabled: getEnvBool("ADMIN_UI_ENABLED", true),
TemplateDir: getEnvString("ADMIN_UI_TEMPLATE_DIR", "web/templates"),
},
// Logging defaults
LogLevel: getEnvString("LOG_LEVEL", "info"),
LogFormat: getEnvString("LOG_FORMAT", "json"),
LogFile: getEnvString("LOG_FILE", ""),
// Audit logging defaults
AuditEnabled: getEnvBool("AUDIT_ENABLED", true),
AuditLogFile: getEnvString("AUDIT_LOG_FILE", "./data/audit.log"),
AuditCreateDir: getEnvBool("AUDIT_CREATE_DIR", true),
AuditStoreInDB: getEnvBool("AUDIT_STORE_IN_DB", true),
ObservabilityEnabled: getEnvBool("OBSERVABILITY_ENABLED", true),
ObservabilityBufferSize: getEnvInt("OBSERVABILITY_BUFFER_SIZE", 1000),
// CORS defaults
CORSAllowedOrigins: getEnvStringSlice("CORS_ALLOWED_ORIGINS", []string{"*"}),
CORSAllowedMethods: getEnvStringSlice("CORS_ALLOWED_METHODS", []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}),
CORSAllowedHeaders: getEnvStringSlice("CORS_ALLOWED_HEADERS", []string{"Authorization", "Content-Type"}),
CORSMaxAge: getEnvDuration("CORS_MAX_AGE", 24*time.Hour),
// Rate limiting defaults
GlobalRateLimit: getEnvInt("GLOBAL_RATE_LIMIT", 100),
IPRateLimit: getEnvInt("IP_RATE_LIMIT", 30),
// Distributed rate limiting defaults
DistributedRateLimitEnabled: getEnvBool("DISTRIBUTED_RATE_LIMIT_ENABLED", false),
DistributedRateLimitPrefix: getEnvString("DISTRIBUTED_RATE_LIMIT_PREFIX", "ratelimit:"),
DistributedRateLimitKeySecret: getEnvString("DISTRIBUTED_RATE_LIMIT_KEY_SECRET", ""),
DistributedRateLimitWindow: getEnvDuration("DISTRIBUTED_RATE_LIMIT_WINDOW", time.Minute),
DistributedRateLimitMax: getEnvInt("DISTRIBUTED_RATE_LIMIT_MAX", 60),
DistributedRateLimitFallback: getEnvBool("DISTRIBUTED_RATE_LIMIT_FALLBACK", true),
// Monitoring defaults
EnableMetrics: getEnvBool("ENABLE_METRICS", true),
MetricsPath: getEnvString("METRICS_PATH", "/metrics"),
// Cleanup defaults
TokenCleanupInterval: getEnvDuration("TOKEN_CLEANUP_INTERVAL", time.Hour),
// Project active guard defaults
EnforceProjectActive: getEnvBool("LLM_PROXY_ENFORCE_PROJECT_ACTIVE", true),
ActiveCacheTTL: getEnvDuration("LLM_PROXY_ACTIVE_CACHE_TTL", 5*time.Second),
ActiveCacheMax: getEnvInt("LLM_PROXY_ACTIVE_CACHE_MAX", 10000),
// Event bus configuration
EventBusBackend: getEnvString("LLM_PROXY_EVENT_BUS", "redis"),
RedisAddr: getEnvString("REDIS_ADDR", "localhost:6379"),
RedisDB: getEnvInt("REDIS_DB", 0),
// Redis Streams configuration
RedisStreamKey: getEnvString("REDIS_STREAM_KEY", "llm-proxy-events"),
RedisConsumerGroup: getEnvString("REDIS_CONSUMER_GROUP", "llm-proxy-dispatchers"),
RedisConsumerName: getEnvString("REDIS_CONSUMER_NAME", ""),
RedisStreamMaxLen: getEnvInt64("REDIS_STREAM_MAX_LEN", 10000),
RedisStreamBlockTime: getEnvDuration("REDIS_STREAM_BLOCK_TIME", 5*time.Second),
RedisStreamClaimTime: getEnvDuration("REDIS_STREAM_CLAIM_TIME", 30*time.Second),
RedisStreamBatchSize: getEnvInt64("REDIS_STREAM_BATCH_SIZE", 100),
// Cache stats aggregation
CacheStatsBufferSize: getEnvInt("CACHE_STATS_BUFFER_SIZE", 1000),
}
// Validate required settings
if config.ManagementToken == "" {
return nil, fmt.Errorf("MANAGEMENT_TOKEN environment variable is required")
}
return config, nil
}
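// Minimal usage sketch (assumes MANAGEMENT_TOKEN is exported; New fails
// fast without it):
//
//	cfg, err := config.New()
//	if err != nil {
//	    log.Fatal(err)
//	}
//	fmt.Println(cfg.ListenAddr) // ":8080" unless LISTEN_ADDR overrides it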
// getEnvString retrieves a string value from an environment variable,
// falling back to the provided default value if the variable is not set.
func getEnvString(key, defaultValue string) string {
if value, exists := os.LookupEnv(key); exists {
return value
}
return defaultValue
}
// getEnvBool retrieves a boolean value from an environment variable,
// falling back to the provided default value if the variable is not set
// or cannot be parsed as a boolean.
func getEnvBool(key string, defaultValue bool) bool {
if value, exists := os.LookupEnv(key); exists {
parsedValue, err := strconv.ParseBool(value)
if err == nil {
return parsedValue
}
}
return defaultValue
}
// getEnvInt retrieves an integer value from an environment variable,
// falling back to the provided default value if the variable is not set
// or cannot be parsed as an integer.
func getEnvInt(key string, defaultValue int) int {
if value, exists := os.LookupEnv(key); exists {
parsedValue, err := strconv.Atoi(value)
if err == nil {
return parsedValue
}
}
return defaultValue
}
// getEnvInt64 retrieves a 64-bit integer value from an environment variable,
// falling back to the provided default value if the variable is not set
// or cannot be parsed as a 64-bit integer.
func getEnvInt64(key string, defaultValue int64) int64 {
if value, exists := os.LookupEnv(key); exists {
parsedValue, err := strconv.ParseInt(value, 10, 64)
if err == nil {
return parsedValue
}
}
return defaultValue
}
// getEnvDuration retrieves a duration value from an environment variable,
// falling back to the provided default value if the variable is not set
// or cannot be parsed as a duration.
func getEnvDuration(key string, defaultValue time.Duration) time.Duration {
if value, exists := os.LookupEnv(key); exists {
parsedValue, err := time.ParseDuration(value)
if err == nil {
return parsedValue
}
}
return defaultValue
}
// getEnvStringSlice retrieves a comma-separated string value from an environment variable
// and splits it into a slice of strings, falling back to the provided default value
// if the variable is not set or is empty.
func getEnvStringSlice(key string, defaultValue []string) []string {
if value, exists := os.LookupEnv(key); exists && value != "" {
parts := strings.Split(value, ",")
for i := range parts {
parts[i] = strings.TrimSpace(parts[i])
}
return parts
}
return defaultValue
}
// LoadFromFile loads configuration from a file (placeholder for future YAML/JSON support).
// The path argument is currently ignored and the default configuration is returned.
func LoadFromFile(path string) (*Config, error) {
// For now, return default config - file loading can be implemented later
return DefaultConfig(), nil
}
// DefaultConfig returns a configuration with default values
func DefaultConfig() *Config {
return &Config{
// Server defaults
ListenAddr: ":8080",
RequestTimeout: 30 * time.Second,
MaxRequestSize: 10 * 1024 * 1024, // 10MB
MaxConcurrentReqs: 100,
// Environment
APIEnv: "development",
// Database defaults
DatabasePath: "./data/llm-proxy.db",
DatabasePoolSize: 10,
// API Provider settings
APIConfigPath: "./config/api_providers.yaml",
DefaultAPIProvider: "openai",
OpenAIAPIURL: "https://api.openai.com",
EnableStreaming: true,
// Admin UI settings
AdminUIPath: "/admin",
AdminUI: AdminUIConfig{
ListenAddr: ":8081",
APIBaseURL: "http://localhost:8080",
ManagementToken: "",
Enabled: true,
TemplateDir: "web/templates",
},
// Logging defaults
LogLevel: "info",
LogFormat: "json",
LogFile: "",
ObservabilityEnabled: true,
ObservabilityBufferSize: 1000,
// CORS defaults
CORSAllowedOrigins: []string{"*"},
CORSAllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
CORSAllowedHeaders: []string{"Authorization", "Content-Type"},
CORSMaxAge: 24 * time.Hour,
// Rate limiting defaults
GlobalRateLimit: 100,
IPRateLimit: 30,
// Distributed rate limiting defaults
DistributedRateLimitEnabled: false,
DistributedRateLimitPrefix: "ratelimit:",
DistributedRateLimitKeySecret: "",
DistributedRateLimitWindow: time.Minute,
DistributedRateLimitMax: 60,
DistributedRateLimitFallback: true,
// Monitoring defaults
EnableMetrics: true,
MetricsPath: "/metrics",
// Cleanup defaults
TokenCleanupInterval: time.Hour,
// Project active guard defaults
EnforceProjectActive: true,
ActiveCacheTTL: 5 * time.Second,
ActiveCacheMax: 10000,
// Event bus configuration
EventBusBackend: "redis",
RedisAddr: "localhost:6379",
RedisDB: 0,
// Redis Streams configuration
RedisStreamKey: "llm-proxy-events",
RedisConsumerGroup: "llm-proxy-dispatchers",
RedisConsumerName: "",
RedisStreamMaxLen: 10000,
RedisStreamBlockTime: 5 * time.Second,
RedisStreamClaimTime: 30 * time.Second,
RedisStreamBatchSize: 100,
// Cache stats aggregation
CacheStatsBufferSize: 1000,
}
}
package config
import (
"os"
"strconv"
)
// EnvOrDefault returns the value of the environment variable if set, otherwise the fallback.
func EnvOrDefault(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
// EnvIntOrDefault returns the int value of the environment variable if set and valid, otherwise the fallback.
func EnvIntOrDefault(key string, fallback int) int {
if v := os.Getenv(key); v != "" {
if i, err := strconv.Atoi(v); err == nil {
return i
}
}
return fallback
}
// EnvBoolOrDefault returns the bool value of the environment variable if set and valid, otherwise the fallback.
func EnvBoolOrDefault(key string, fallback bool) bool {
if v := os.Getenv(key); v != "" {
if b, err := strconv.ParseBool(v); err == nil {
return b
}
}
return fallback
}
// EnvFloat64OrDefault returns the float64 value of the environment variable if set and valid, otherwise the fallback.
func EnvFloat64OrDefault(key string, fallback float64) float64 {
if v := os.Getenv(key); v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil {
return f
}
}
return fallback
}
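// Behavior sketch (variable names are illustrative):
//
//	config.EnvOrDefault("MISSING_VAR", "fallback")   // "fallback"
//	config.EnvIntOrDefault("NOT_A_NUMBER", 42)       // 42 when parsing fails
//	config.EnvBoolOrDefault("FEATURE_FLAG", false)   // parsed via strconv.ParseBool
//	config.EnvFloat64OrDefault("SAMPLE_RATE", 0.1)   // 0.1 when unset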
package database
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/google/uuid"
"github.com/sofatutor/llm-proxy/internal/audit"
)
// AuditStore defines the interface for persisting audit events to database
type AuditStore interface {
StoreAuditEvent(ctx context.Context, event *audit.Event) error
ListAuditEvents(ctx context.Context, filters AuditEventFilters) ([]AuditEvent, error)
CountAuditEvents(ctx context.Context, filters AuditEventFilters) (int, error)
GetAuditEventByID(ctx context.Context, id string) (*AuditEvent, error)
}
// AuditEventFilters provides filtering options for audit event queries
type AuditEventFilters struct {
Action string
ClientIP string
ProjectID string
StartTime *string // RFC3339 format
EndTime *string // RFC3339 format
Outcome string
Actor string
RequestID string
CorrelationID string
Method string
Path string
Search string // Substring (LIKE) search across request/correlation IDs, client IP, action, actor, method, path, reason, and metadata
Limit int
Offset int
}
// StoreAuditEvent persists an audit event to the database
func (d *DB) StoreAuditEvent(ctx context.Context, event *audit.Event) error {
if d == nil || d.db == nil {
return fmt.Errorf("database is nil")
}
if event == nil {
return fmt.Errorf("audit event cannot be nil")
}
// Generate UUID for the audit event
id := uuid.New().String()
// Convert metadata to JSON string if present
var metadataJSON *string
if len(event.Details) > 0 {
metadataBytes, err := json.Marshal(event.Details)
if err != nil {
return fmt.Errorf("failed to marshal audit event metadata: %w", err)
}
metadataStr := string(metadataBytes)
metadataJSON = &metadataStr
}
// Extract common fields from details for first-class columns
var method, path, userAgent, reason, tokenID *string
if event.Details != nil {
if v, ok := event.Details["http_method"].(string); ok {
method = &v
}
if v, ok := event.Details["endpoint"].(string); ok {
path = &v
}
if v, ok := event.Details["user_agent"].(string); ok {
userAgent = &v
}
if v, ok := event.Details["error"].(string); ok {
reason = &v
}
if v, ok := event.Details["token_id"].(string); ok {
tokenID = &v
}
}
// Convert optional string fields to pointers
var projectID, requestID, correlationID, clientIP *string
if event.ProjectID != "" {
projectID = &event.ProjectID
}
if event.RequestID != "" {
requestID = &event.RequestID
}
if event.CorrelationID != "" {
correlationID = &event.CorrelationID
}
if event.ClientIP != "" {
clientIP = &event.ClientIP
}
query := `
INSERT INTO audit_events (
id, timestamp, action, actor, project_id, request_id, correlation_id,
client_ip, method, path, user_agent, outcome, reason, token_id, metadata
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := d.ExecContextRebound(ctx, query,
id,
event.Timestamp,
event.Action,
event.Actor,
projectID,
requestID,
correlationID,
clientIP,
method,
path,
userAgent,
string(event.Result),
reason,
tokenID,
metadataJSON,
)
if err != nil {
return fmt.Errorf("failed to insert audit event: %w", err)
}
return nil
}
// ListAuditEvents retrieves audit events from the database with optional filtering
func (d *DB) ListAuditEvents(ctx context.Context, filters AuditEventFilters) ([]AuditEvent, error) {
query := "SELECT id, timestamp, action, actor, project_id, request_id, correlation_id, client_ip, method, path, user_agent, outcome, reason, token_id, metadata FROM audit_events WHERE 1=1"
args := []interface{}{}
// Apply filters
if filters.Action != "" {
query += " AND action = ?"
args = append(args, filters.Action)
}
if filters.ClientIP != "" {
query += " AND client_ip = ?"
args = append(args, filters.ClientIP)
}
if filters.ProjectID != "" {
query += " AND project_id = ?"
args = append(args, filters.ProjectID)
}
if filters.Outcome != "" {
query += " AND outcome = ?"
args = append(args, filters.Outcome)
}
if filters.Actor != "" {
query += " AND actor = ?"
args = append(args, filters.Actor)
}
if filters.RequestID != "" {
query += " AND request_id = ?"
args = append(args, filters.RequestID)
}
if filters.CorrelationID != "" {
query += " AND correlation_id = ?"
args = append(args, filters.CorrelationID)
}
if filters.Method != "" {
query += " AND method = ?"
args = append(args, filters.Method)
}
if filters.Path != "" {
query += " AND path = ?"
args = append(args, filters.Path)
}
if filters.Search != "" {
query += " AND (request_id LIKE ? OR correlation_id LIKE ? OR client_ip LIKE ? OR action LIKE ? OR actor LIKE ? OR method LIKE ? OR path LIKE ? OR reason LIKE ? OR metadata LIKE ?)"
searchPattern := "%" + filters.Search + "%"
args = append(args, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern)
}
if filters.StartTime != nil {
query += " AND timestamp >= datetime(?)"
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
query += " AND timestamp <= datetime(?)"
args = append(args, *filters.EndTime)
}
// Order by timestamp descending
query += " ORDER BY timestamp DESC"
// Apply limit and offset
if filters.Limit > 0 {
query += " LIMIT ?"
args = append(args, filters.Limit)
if filters.Offset > 0 {
query += " OFFSET ?"
args = append(args, filters.Offset)
}
}
rows, err := d.QueryContextRebound(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to query audit events: %w", err)
}
defer func() { _ = rows.Close() }()
var events []AuditEvent
for rows.Next() {
var event AuditEvent
err := rows.Scan(
&event.ID,
&event.Timestamp,
&event.Action,
&event.Actor,
&event.ProjectID,
&event.RequestID,
&event.CorrelationID,
&event.ClientIP,
&event.Method,
&event.Path,
&event.UserAgent,
&event.Outcome,
&event.Reason,
&event.TokenID,
&event.Metadata,
)
if err != nil {
return nil, fmt.Errorf("failed to scan audit event: %w", err)
}
events = append(events, event)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating audit events: %w", err)
}
return events, nil
}
// CountAuditEvents returns the total count of audit events matching the given filters
func (d *DB) CountAuditEvents(ctx context.Context, filters AuditEventFilters) (int, error) {
query := "SELECT COUNT(*) FROM audit_events WHERE 1=1"
args := []interface{}{}
// Apply the same filters as ListAuditEvents (excluding limit/offset)
if filters.Action != "" {
query += " AND action = ?"
args = append(args, filters.Action)
}
if filters.ClientIP != "" {
query += " AND client_ip = ?"
args = append(args, filters.ClientIP)
}
if filters.ProjectID != "" {
query += " AND project_id = ?"
args = append(args, filters.ProjectID)
}
if filters.Outcome != "" {
query += " AND outcome = ?"
args = append(args, filters.Outcome)
}
if filters.Actor != "" {
query += " AND actor = ?"
args = append(args, filters.Actor)
}
if filters.RequestID != "" {
query += " AND request_id = ?"
args = append(args, filters.RequestID)
}
if filters.CorrelationID != "" {
query += " AND correlation_id = ?"
args = append(args, filters.CorrelationID)
}
if filters.Method != "" {
query += " AND method = ?"
args = append(args, filters.Method)
}
if filters.Path != "" {
query += " AND path = ?"
args = append(args, filters.Path)
}
if filters.Search != "" {
query += " AND (request_id LIKE ? OR correlation_id LIKE ? OR client_ip LIKE ? OR action LIKE ? OR actor LIKE ? OR method LIKE ? OR path LIKE ? OR reason LIKE ? OR metadata LIKE ?)"
searchPattern := "%" + filters.Search + "%"
args = append(args, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern)
}
if filters.StartTime != nil {
query += " AND timestamp >= datetime(?)"
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
query += " AND timestamp <= datetime(?)"
args = append(args, *filters.EndTime)
}
var count int
err := d.QueryRowContextRebound(ctx, query, args...).Scan(&count)
if err != nil {
return 0, fmt.Errorf("failed to count audit events: %w", err)
}
return count, nil
}
// GetAuditEventByID retrieves a specific audit event by its ID
func (d *DB) GetAuditEventByID(ctx context.Context, id string) (*AuditEvent, error) {
query := "SELECT id, timestamp, action, actor, project_id, request_id, correlation_id, client_ip, method, path, user_agent, outcome, reason, token_id, metadata FROM audit_events WHERE id = ?"
row := d.QueryRowContextRebound(ctx, query, id)
var event AuditEvent
err := row.Scan(
&event.ID,
&event.Timestamp,
&event.Action,
&event.Actor,
&event.ProjectID,
&event.RequestID,
&event.CorrelationID,
&event.ClientIP,
&event.Method,
&event.Path,
&event.UserAgent,
&event.Outcome,
&event.Reason,
&event.TokenID,
&event.Metadata,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("audit event not found")
}
return nil, fmt.Errorf("failed to get audit event: %w", err)
}
return &event, nil
}
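// Query sketch: the most recent failed token validations for one project
// (filter values are illustrative):
//
//	events, err := db.ListAuditEvents(ctx, database.AuditEventFilters{
//	    Action:    audit.ActionTokenValidate,
//	    Outcome:   string(audit.ResultFailure),
//	    ProjectID: "proj-123",
//	    Limit:     50,
//	})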
// Package database provides SQLite database operations for the LLM Proxy.
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"runtime"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"github.com/sofatutor/llm-proxy/internal/database/migrations"
)
// DB represents the database connection.
type DB struct {
db *sql.DB
driver DriverType
}
// Config contains the database configuration.
type Config struct {
// Path is the path to the SQLite database file.
Path string
// MaxOpenConns is the maximum number of open connections.
MaxOpenConns int
// MaxIdleConns is the maximum number of idle connections.
MaxIdleConns int
// ConnMaxLifetime is the maximum amount of time a connection may be reused.
ConnMaxLifetime time.Duration
}
// DefaultConfig returns a default database configuration.
func DefaultConfig() Config {
return Config{
Path: "data/llm-proxy.db",
MaxOpenConns: 10,
MaxIdleConns: 5,
ConnMaxLifetime: time.Hour,
}
}
// New creates a new database connection.
func New(config Config) (*DB, error) {
// Ensure database directory exists
if err := ensureDirExists(filepath.Dir(config.Path)); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err)
}
// Open connection
db, err := sql.Open("sqlite3", config.Path+"?_journal=WAL&_foreign_keys=on")
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// Configure connection pool
// Special case: in-memory SQLite databases are per-connection. Use a single connection
// to ensure schema and data are visible across queries within the same *sql.DB handle.
if config.Path == ":memory:" {
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
} else {
db.SetMaxOpenConns(config.MaxOpenConns)
db.SetMaxIdleConns(config.MaxIdleConns)
}
db.SetConnMaxLifetime(config.ConnMaxLifetime)
// Test the connection
if err := db.Ping(); err != nil {
_ = db.Close()
return nil, fmt.Errorf("failed to ping database: %w", err)
}
// Run database migrations
if err := runMigrations(db); err != nil {
_ = db.Close()
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
return &DB{db: db, driver: DriverSQLite}, nil
}
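// Connection sketch using the defaults above:
//
//	db, err := database.New(database.DefaultConfig())
//	if err != nil {
//	    return err
//	}
//	defer func() { _ = db.Close() }()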
// Close closes the database connection.
func (d *DB) Close() error {
if d.db != nil {
_ = d.db.Close()
}
return nil
}
// ensureDirExists creates the directory if it doesn't exist.
func ensureDirExists(dir string) error {
info, err := os.Stat(dir)
if errors.Is(err, fs.ErrNotExist) {
return os.MkdirAll(dir, 0755)
} else if err != nil {
return err
}
if !info.IsDir() {
return fmt.Errorf("path %s exists and is not a directory", dir)
}
return nil
}
// getMigrationsPath returns the path to the migrations directory.
// It tries multiple strategies to locate the migrations:
// 1. Relative path from current working directory (for development)
// 2. Path relative to this source file (for tests)
// 3. Relative path from executable location (for production)
// The returned error lists every attempted path to help diagnose resolution issues in production.
func getMigrationsPath() (string, error) {
var triedPaths []string
// Try relative path from current working directory first (development)
relPath := "internal/database/migrations/sql"
triedPaths = append(triedPaths, relPath)
if _, err := os.Stat(relPath); err == nil {
return relPath, nil
}
// Try path relative to this source file (for tests)
_, filename, _, ok := runtime.Caller(0)
if ok {
// Get directory of this file (database.go)
sourceDir := filepath.Dir(filename)
// migrations/sql is sibling to database package
sourceRelPath := filepath.Join(sourceDir, "migrations", "sql")
triedPaths = append(triedPaths, sourceRelPath)
if _, err := os.Stat(sourceRelPath); err == nil {
return sourceRelPath, nil
}
}
// Try to get path relative to executable
execPath, err := os.Executable()
if err == nil {
execDir := filepath.Dir(execPath)
// Try relative to executable directory
execRelPath := filepath.Join(execDir, "internal/database/migrations/sql")
triedPaths = append(triedPaths, execRelPath)
if _, err := os.Stat(execRelPath); err == nil {
return execRelPath, nil
}
// Try relative to executable's parent (if executable is in bin/)
binRelPath := filepath.Join(filepath.Dir(execDir), "internal/database/migrations/sql")
triedPaths = append(triedPaths, binRelPath)
if _, err := os.Stat(binRelPath); err == nil {
return binRelPath, nil
}
}
return "", fmt.Errorf("migrations directory not found: tried paths %v", triedPaths)
}
// runMigrations runs database migrations using the migration runner.
func runMigrations(db *sql.DB) error {
migrationsPath, err := getMigrationsPath()
if err != nil {
return fmt.Errorf("failed to get migrations path: %w", err)
}
runner := migrations.NewMigrationRunner(db, migrationsPath)
if err := runner.Up(); err != nil {
return fmt.Errorf("failed to apply migrations: %w", err)
}
return nil
}
// initDatabase applies all migrations.
//
// Deprecated: Use runMigrations instead. Kept for backward compatibility
// with DBInitForTests.
func initDatabase(db *sql.DB) error {
return runMigrations(db)
}
// DBInitForTests is a helper to ensure schema exists in tests. No-op if db is nil.
func DBInitForTests(d *DB) error {
if d == nil || d.db == nil {
return nil
}
return initDatabase(d.db)
}
// Transaction executes the given function within a transaction.
func (d *DB) Transaction(ctx context.Context, fn func(*sql.Tx) error) error {
if d == nil || d.db == nil {
return fmt.Errorf("database is nil")
}
tx, err := d.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
// If the function panics, rollback the transaction
defer func() {
if p := recover(); p != nil {
_ = tx.Rollback()
panic(p) // Re-throw the panic after rolling back
}
}()
// Execute the function
if err := fn(tx); err != nil {
_ = tx.Rollback()
return err
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
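// Usage sketch: the callback either commits (nil return) or rolls back
// (non-nil return or panic); the SQL is illustrative:
//
//	err := db.Transaction(ctx, func(tx *sql.Tx) error {
//	    if _, err := tx.ExecContext(ctx, "UPDATE tokens SET is_active = 0 WHERE project_id = ?", id); err != nil {
//	        return err // triggers rollback
//	    }
//	    return nil // commits
//	})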
// DB returns the underlying sql.DB instance.
func (d *DB) DB() *sql.DB {
return d.db
}
// Driver returns the database driver type.
func (d *DB) Driver() DriverType {
return d.driver
}
// HealthCheck performs a health check on the database connection.
// It verifies that the database is reachable and responsive.
func (d *DB) HealthCheck(ctx context.Context) error {
if d == nil || d.db == nil {
return fmt.Errorf("database is nil")
}
// Test the connection with a simple query
if err := d.db.PingContext(ctx); err != nil {
return fmt.Errorf("database ping failed: %w", err)
}
// Verify we can execute a simple query
var result int
err := d.db.QueryRowContext(ctx, "SELECT 1").Scan(&result)
if err != nil {
return fmt.Errorf("database query failed: %w", err)
}
return nil
}
// Package database provides database operations for the LLM Proxy.
package database
import (
"database/sql"
"fmt"
"log"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/sofatutor/llm-proxy/internal/database/migrations"
)
// DriverType represents the database driver type.
type DriverType string
const (
// DriverSQLite represents the SQLite database driver.
DriverSQLite DriverType = "sqlite"
// DriverPostgres represents the PostgreSQL database driver.
DriverPostgres DriverType = "postgres"
)
// FullConfig contains the complete database configuration for all drivers.
type FullConfig struct {
// Driver specifies which database driver to use (sqlite, postgres).
Driver DriverType
// SQLite-specific configuration
// Path is the path to the SQLite database file.
Path string
// PostgreSQL-specific configuration
// DatabaseURL is the PostgreSQL connection string.
DatabaseURL string
// Connection pool settings (used by both drivers)
// MaxOpenConns is the maximum number of open connections.
MaxOpenConns int
// MaxIdleConns is the maximum number of idle connections.
MaxIdleConns int
// ConnMaxLifetime is the maximum amount of time a connection may be reused.
ConnMaxLifetime time.Duration
}
// DefaultFullConfig returns a default database configuration.
func DefaultFullConfig() FullConfig {
return FullConfig{
Driver: DriverSQLite,
Path: "data/llm-proxy.db",
DatabaseURL: "",
MaxOpenConns: 10,
MaxIdleConns: 5,
ConnMaxLifetime: time.Hour,
}
}
// ConfigFromEnv creates a FullConfig from environment variables.
// Invalid configuration values are logged as warnings and defaults are used.
func ConfigFromEnv() FullConfig {
config := DefaultFullConfig()
if driver := os.Getenv("DB_DRIVER"); driver != "" {
driverType := DriverType(strings.ToLower(driver))
if driverType != DriverSQLite && driverType != DriverPostgres {
log.Printf("Warning: unsupported DB_DRIVER '%s', defaulting to sqlite", driver)
} else {
config.Driver = driverType
}
}
if path := os.Getenv("DATABASE_PATH"); path != "" {
config.Path = path
}
if url := os.Getenv("DATABASE_URL"); url != "" {
config.DatabaseURL = url
}
if poolSize := os.Getenv("DATABASE_POOL_SIZE"); poolSize != "" {
if size, err := parsePositiveInt(poolSize); err == nil {
config.MaxOpenConns = size
} else {
log.Printf("Warning: invalid DATABASE_POOL_SIZE '%s': %v, using default %d", poolSize, err, config.MaxOpenConns)
}
}
if idleConns := os.Getenv("DATABASE_MAX_IDLE_CONNS"); idleConns != "" {
if size, err := parsePositiveInt(idleConns); err == nil {
config.MaxIdleConns = size
} else {
log.Printf("Warning: invalid DATABASE_MAX_IDLE_CONNS '%s': %v, using default %d", idleConns, err, config.MaxIdleConns)
}
}
if lifetime := os.Getenv("DATABASE_CONN_MAX_LIFETIME"); lifetime != "" {
if duration, err := time.ParseDuration(lifetime); err == nil {
config.ConnMaxLifetime = duration
} else {
log.Printf("Warning: invalid DATABASE_CONN_MAX_LIFETIME '%s': %v, using default %v", lifetime, err, config.ConnMaxLifetime)
}
}
return config
}
// parsePositiveInt strictly parses a string as a positive integer.
// strconv.Atoi is used instead of fmt.Sscanf so that trailing garbage
// (e.g., "12abc") is rejected rather than silently truncated.
func parsePositiveInt(s string) (int, error) {
i, err := strconv.Atoi(s)
if err != nil || i <= 0 {
return 0, fmt.Errorf("invalid positive integer: %s", s)
}
return i, nil
}
// NewFromConfig creates a new database connection based on the configuration.
func NewFromConfig(config FullConfig) (*DB, error) {
switch config.Driver {
case DriverSQLite:
return newSQLiteDB(config)
case DriverPostgres:
return newPostgresDB(config)
default:
return nil, fmt.Errorf("unsupported database driver: %s", config.Driver)
}
}
// newSQLiteDB creates a new SQLite database connection.
func newSQLiteDB(config FullConfig) (*DB, error) {
// Ensure database directory exists (skip for in-memory databases)
if config.Path != ":memory:" {
if err := ensureDirExists(filepath.Dir(config.Path)); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err)
}
}
// Open connection
db, err := sql.Open("sqlite3", config.Path+"?_journal=WAL&_foreign_keys=on")
if err != nil {
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
}
// Configure connection pool
// Special case: in-memory SQLite databases are per-connection. Use a single connection
// to ensure schema and data are visible across queries within the same *sql.DB handle.
if config.Path == ":memory:" {
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
} else {
db.SetMaxOpenConns(config.MaxOpenConns)
db.SetMaxIdleConns(config.MaxIdleConns)
}
db.SetConnMaxLifetime(config.ConnMaxLifetime)
// Test the connection
if err := db.Ping(); err != nil {
_ = db.Close()
return nil, fmt.Errorf("failed to ping SQLite database: %w", err)
}
// Run database migrations
if err := runMigrationsForDriver(db, "sqlite3"); err != nil {
_ = db.Close()
return nil, fmt.Errorf("failed to run SQLite migrations: %w", err)
}
return &DB{db: db, driver: DriverSQLite}, nil
}
// runMigrationsForDriver runs database migrations for the specified driver.
func runMigrationsForDriver(db *sql.DB, dialect string) error {
migrationsPath, err := getMigrationsPathForDialect(dialect)
if err != nil {
return fmt.Errorf("failed to get migrations path: %w", err)
}
runner := migrations.NewMigrationRunner(db, migrationsPath)
if err := runner.Up(); err != nil {
return fmt.Errorf("failed to apply migrations: %w", err)
}
return nil
}
// getMigrationsPathForDialect returns the path to migrations for the specified dialect.
// It looks for dialect-specific migrations first (e.g., sql/postgres/), then falls back
// to the common migrations directory (sql/).
func getMigrationsPathForDialect(dialect string) (string, error) {
// Common base paths to try
basePaths := []string{
"internal/database/migrations",
}
// Add path relative to this source file (for tests)
_, filename, _, ok := runtime.Caller(0)
if ok {
sourceDir := filepath.Dir(filename)
basePaths = append(basePaths, filepath.Join(sourceDir, "migrations"))
}
// Add paths relative to executable
execPath, err := os.Executable()
if err == nil {
execDir := filepath.Dir(execPath)
basePaths = append(basePaths, filepath.Join(execDir, "internal/database/migrations"))
basePaths = append(basePaths, filepath.Join(filepath.Dir(execDir), "internal/database/migrations"))
}
// Dialect-specific subdirectory (e.g., "postgres", "sqlite3")
dialectDir := dialect
if dialect == "sqlite3" {
dialectDir = "sqlite"
}
// Try each base path
for _, basePath := range basePaths {
// First, try dialect-specific directory
dialectPath := filepath.Join(basePath, "sql", dialectDir)
if _, err := os.Stat(dialectPath); err == nil {
return dialectPath, nil
}
// Fall back to common SQL directory
commonPath := filepath.Join(basePath, "sql")
if _, err := os.Stat(commonPath); err == nil {
return commonPath, nil
}
}
return "", fmt.Errorf("migrations directory not found for dialect: %s", dialect)
}
//go:build !postgres
package database
import "fmt"
// newPostgresDB is a stub that returns an error when PostgreSQL support
// is not compiled in. The real implementation is in factory_postgres.go
// and requires the 'postgres' build tag.
//
// To enable PostgreSQL support, build with: go build -tags postgres ./...
func newPostgresDB(_ FullConfig) (*DB, error) {
return nil, fmt.Errorf("PostgreSQL support not compiled in; build with -tags postgres to enable")
}
//go:build !postgres
package migrations
import "fmt"
// acquirePostgresLock is a stub that returns an error when PostgreSQL support
// is not compiled in. The real implementation is in postgres_lock.go and requires
// the 'postgres' build tag.
//
// PostgreSQL locking will be tested via Docker Compose integration tests (issue #139).
func (m *MigrationRunner) acquirePostgresLock() (func(), error) {
return nil, fmt.Errorf("PostgreSQL advisory locking requires the 'postgres' build tag; see issue #139 for Docker Compose integration tests")
}
// Package migrations provides database migration functionality using goose.
package migrations
import (
"database/sql"
"fmt"
"os"
"strings"
"time"
"github.com/pressly/goose/v3"
)
// MigrationRunner manages database migrations using goose.
type MigrationRunner struct {
db *sql.DB
migrationsPath string
}
// NewMigrationRunner creates a new migration runner.
// db is the database connection, migrationsPath is the directory containing SQL migration files.
func NewMigrationRunner(db *sql.DB, migrationsPath string) *MigrationRunner {
return &MigrationRunner{
db: db,
migrationsPath: migrationsPath,
}
}
// Up applies all pending migrations.
// Each migration runs in a transaction and will be rolled back if it fails.
// Advisory locking is used to prevent concurrent migrations in distributed systems.
func (m *MigrationRunner) Up() error {
if m.db == nil {
return fmt.Errorf("database connection is nil")
}
if m.migrationsPath == "" {
return fmt.Errorf("migrations path is empty")
}
// Acquire advisory lock to prevent concurrent migrations
release, err := m.acquireMigrationLock()
if err != nil {
return fmt.Errorf("failed to acquire migration lock: %w", err)
}
defer release()
// Detect database driver and set goose dialect
driverName, err := m.detectDriver()
if err != nil {
return fmt.Errorf("failed to detect database driver: %w", err)
}
if err := goose.SetDialect(driverName); err != nil {
return fmt.Errorf("failed to set goose dialect: %w", err)
}
// Run migrations
if err := goose.Up(m.db, m.migrationsPath); err != nil {
return fmt.Errorf("failed to apply migrations: %w", err)
}
return nil
}
// Down rolls back the most recently applied migration.
// The rollback runs in a transaction.
func (m *MigrationRunner) Down() error {
if m.db == nil {
return fmt.Errorf("database connection is nil")
}
if m.migrationsPath == "" {
return fmt.Errorf("migrations path is empty")
}
// Acquire migration lock to prevent concurrent operations
release, err := m.acquireMigrationLock()
if err != nil {
return fmt.Errorf("failed to acquire migration lock: %w", err)
}
defer release()
// Detect database driver and set goose dialect
driverName, err := m.detectDriver()
if err != nil {
return fmt.Errorf("failed to detect database driver: %w", err)
}
if err := goose.SetDialect(driverName); err != nil {
return fmt.Errorf("failed to set goose dialect: %w", err)
}
// Roll back one migration
if err := goose.Down(m.db, m.migrationsPath); err != nil {
return fmt.Errorf("failed to roll back migration: %w", err)
}
return nil
}
// Status returns the current migration version.
// Returns 0 if no migrations have been applied.
func (m *MigrationRunner) Status() (int64, error) {
if m.db == nil {
return 0, fmt.Errorf("database connection is nil")
}
if m.migrationsPath == "" {
return 0, fmt.Errorf("migrations path is empty")
}
// Detect database driver and set goose dialect
driverName, err := m.detectDriver()
if err != nil {
return 0, fmt.Errorf("failed to detect database driver: %w", err)
}
if err := goose.SetDialect(driverName); err != nil {
return 0, fmt.Errorf("failed to set goose dialect: %w", err)
}
// Get current version
version, err := goose.GetDBVersion(m.db)
if err != nil {
return 0, fmt.Errorf("failed to get migration version: %w", err)
}
return version, nil
}
// Version is an alias for Status(). Returns the current migration version.
func (m *MigrationRunner) Version() (int64, error) {
return m.Status()
}
// detectDriver detects the database driver from the connection.
// Returns the goose dialect name: "sqlite3" or "postgres".
func (m *MigrationRunner) detectDriver() (string, error) {
if m.db == nil {
return "", fmt.Errorf("database connection is nil")
}
// Get driver name from connection
driverName := m.db.Driver()
if driverName == nil {
return "", fmt.Errorf("driver is nil")
}
driverType := fmt.Sprintf("%T", driverName)
// Detect SQLite
if driverType == "*sqlite3.SQLiteDriver" || driverType == "*sqlite3.SQLiteConn" {
return "sqlite3", nil
}
// Detect PostgreSQL (common drivers: lib/pq, pgx, pgx/stdlib)
// pgx/stdlib registers as *stdlib.Driver when using sql.Open("pgx", ...)
if driverType == "*pq.driver" || driverType == "*pgx.Conn" || driverType == "*pgxpool.Pool" ||
driverType == "*stdlib.Driver" {
return "postgres", nil
}
// Try to detect via database-specific queries
// First, try PostgreSQL-specific query
var version string
pgErr := m.db.QueryRow("SELECT version()").Scan(&version)
if pgErr == nil && strings.HasPrefix(version, "PostgreSQL") {
// version() is PostgreSQL-specific and returns something like "PostgreSQL 15.x ..."
return "postgres", nil
}
// Try SQLite-specific pragma
var journalMode string
pragmaErr := m.db.QueryRow("PRAGMA journal_mode").Scan(&journalMode)
if pragmaErr == nil {
return "sqlite3", nil
}
// Default to sqlite3 for backward compatibility
return "sqlite3", nil
}
// acquireMigrationLock acquires an advisory lock to prevent concurrent migrations.
// Returns a release function that must be called to release the lock.
func (m *MigrationRunner) acquireMigrationLock() (func(), error) {
driverName, err := m.detectDriver()
if err != nil {
return nil, fmt.Errorf("failed to detect driver for locking: %w", err)
}
switch driverName {
case "sqlite3":
return m.acquireSQLiteLock()
case "postgres":
return m.acquirePostgresLock()
default:
return m.acquireSQLiteLock() // Default to SQLite lock for backward compatibility
}
}
// acquireSQLiteLock acquires a lock using a SQLite lock table.
// This prevents concurrent migrations when multiple instances start simultaneously.
func (m *MigrationRunner) acquireSQLiteLock() (func(), error) {
// Create lock table if it doesn't exist
_, err := m.db.Exec(`
CREATE TABLE IF NOT EXISTS migration_lock (
id INTEGER PRIMARY KEY CHECK (id = 1),
locked BOOLEAN NOT NULL DEFAULT 0,
locked_at DATETIME,
locked_by TEXT,
process_id INTEGER
)
`)
if err != nil {
return nil, fmt.Errorf("failed to create lock table: %w", err)
}
// Initialize lock row if it doesn't exist
_, _ = m.db.Exec(`INSERT OR IGNORE INTO migration_lock (id, locked) VALUES (1, 0)`)
// Try to acquire lock with retries
maxRetries := 10
retryDelay := 100 * time.Millisecond
processID := os.Getpid()
for i := 0; i < maxRetries; i++ {
// Use a transaction to atomically check and acquire lock
tx, err := m.db.Begin()
if err != nil {
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
var locked bool
err = tx.QueryRow(`SELECT locked FROM migration_lock WHERE id = 1`).Scan(&locked)
if err != nil {
_ = tx.Rollback()
return nil, fmt.Errorf("failed to read lock status: %w", err)
}
if locked {
_ = tx.Rollback()
if i < maxRetries-1 {
time.Sleep(retryDelay)
continue
}
return nil, fmt.Errorf("migration lock is already held by another process (retried %d times)", maxRetries)
}
// Acquire lock
result, err := tx.Exec(`
UPDATE migration_lock
SET locked = 1, locked_at = CURRENT_TIMESTAMP, locked_by = ?, process_id = ?
WHERE id = 1 AND locked = 0
`, fmt.Sprintf("pid-%d", processID), processID)
if err != nil {
_ = tx.Rollback()
return nil, fmt.Errorf("failed to acquire lock: %w", err)
}
// Verify that the update actually affected a row (lock was acquired)
rowsAffected, err := result.RowsAffected()
if err != nil || rowsAffected == 0 {
_ = tx.Rollback()
if i < maxRetries-1 {
time.Sleep(retryDelay)
continue
}
return nil, fmt.Errorf("failed to acquire lock: another process may have acquired it")
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("failed to commit lock acquisition: %w", err)
}
// Verify lock was acquired
var isLocked bool
err = m.db.QueryRow(`SELECT locked FROM migration_lock WHERE id = 1`).Scan(&isLocked)
if err != nil || !isLocked {
if i < maxRetries-1 {
time.Sleep(retryDelay)
continue
}
return nil, fmt.Errorf("failed to verify lock acquisition")
}
// Return release function
release := func() {
_, _ = m.db.Exec(`UPDATE migration_lock SET locked = 0 WHERE id = 1`)
}
return release, nil
}
return nil, fmt.Errorf("failed to acquire migration lock after %d retries", maxRetries)
}
// NOTE: acquirePostgresLock is defined in postgres_lock.go (with postgres build tag)
// and postgres_lock_stub.go (without postgres build tag). This allows PostgreSQL-specific
// code to be excluded from default coverage calculations. PostgreSQL integration tests
// will be added via Docker Compose in issue #139.
package database
import (
"context"
"errors"
"sync"
"github.com/sofatutor/llm-proxy/internal/proxy"
)
// MockProjectStore is an in-memory implementation of ProjectStore for testing and development
type MockProjectStore struct {
projects map[string]Project
apiKeys map[string]string // Project ID -> API Key mapping
mutex sync.RWMutex
}
// NewMockProjectStore creates a new MockProjectStore
func NewMockProjectStore() *MockProjectStore {
return &MockProjectStore{
projects: make(map[string]Project),
apiKeys: make(map[string]string),
}
}
// DBCreateProject creates a new project in the store
func (m *MockProjectStore) DBCreateProject(ctx context.Context, project Project) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.projects[project.ID]; exists {
return errors.New("project already exists")
}
m.projects[project.ID] = project
m.apiKeys[project.ID] = project.OpenAIAPIKey
return nil
}
// DBGetProjectByID retrieves a project by ID
func (m *MockProjectStore) DBGetProjectByID(ctx context.Context, projectID string) (Project, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
project, exists := m.projects[projectID]
if !exists {
return Project{}, errors.New("project not found")
}
return project, nil
}
// DBUpdateProject updates a project in the store
func (m *MockProjectStore) DBUpdateProject(ctx context.Context, project Project) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.projects[project.ID]; !exists {
return errors.New("project not found")
}
m.projects[project.ID] = project
m.apiKeys[project.ID] = project.OpenAIAPIKey
return nil
}
// DBDeleteProject deletes a project from the store
func (m *MockProjectStore) DBDeleteProject(ctx context.Context, projectID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.projects[projectID]; !exists {
return errors.New("project not found")
}
delete(m.projects, projectID)
delete(m.apiKeys, projectID)
return nil
}
// DBListProjects retrieves all projects from the store
func (m *MockProjectStore) DBListProjects(ctx context.Context) ([]Project, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
projects := make([]Project, 0, len(m.projects))
for _, p := range m.projects {
projects = append(projects, p)
}
return projects, nil
}
// CreateMockProject creates a new project in the mock store with the given parameters
func (m *MockProjectStore) CreateMockProject(projectID, name, apiKey string) (Project, error) {
if projectID == "" {
return Project{}, errors.New("project ID cannot be empty")
}
if name == "" {
return Project{}, errors.New("project name cannot be empty")
}
if apiKey == "" {
return Project{}, errors.New("API key cannot be empty")
}
project := Project{
ID: projectID,
Name: name,
OpenAIAPIKey: apiKey,
}
err := m.DBCreateProject(context.Background(), project)
return project, err
}
// GetAPIKeyForProject retrieves the API key for a project
func (m *MockProjectStore) GetAPIKeyForProject(ctx context.Context, projectID string) (string, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
apiKey, exists := m.apiKeys[projectID]
if !exists {
return "", errors.New("project not found")
}
return apiKey, nil
}
// --- proxy.ProjectStore interface adapters ---
func (m *MockProjectStore) ListProjects(ctx context.Context) ([]proxy.Project, error) {
dbProjects, err := m.DBListProjects(ctx)
if err != nil {
return nil, err
}
var out []proxy.Project
for _, p := range dbProjects {
out = append(out, ToProxyProject(p))
}
return out, nil
}
func (m *MockProjectStore) CreateProject(ctx context.Context, p proxy.Project) error {
return m.DBCreateProject(ctx, ToDBProject(p))
}
func (m *MockProjectStore) GetProjectByID(ctx context.Context, id string) (proxy.Project, error) {
dbP, err := m.DBGetProjectByID(ctx, id)
if err != nil {
return proxy.Project{}, err
}
return ToProxyProject(dbP), nil
}
func (m *MockProjectStore) UpdateProject(ctx context.Context, p proxy.Project) error {
return m.DBUpdateProject(ctx, ToDBProject(p))
}
func (m *MockProjectStore) DeleteProject(ctx context.Context, id string) error {
return m.DBDeleteProject(ctx, id)
}
// GetProjectActive retrieves the active status for a project by ID
func (m *MockProjectStore) GetProjectActive(ctx context.Context, projectID string) (bool, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
project, exists := m.projects[projectID]
if !exists {
return false, errors.New("project not found")
}
return project.IsActive, nil
}
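// Usage sketch (illustrative; the project ID, name, and API key below are hypothetical):
//
//	store := NewMockProjectStore()
//	project, err := store.CreateMockProject("proj-1", "demo", "sk-test-key")
//	if err != nil {
//		log.Fatal(err)
//	}
//	active, _ := store.GetProjectActive(context.Background(), project.ID)
//	_ = active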
package database
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/sofatutor/llm-proxy/internal/token"
)
// MockTokenStore is an in-memory implementation of TokenStore for testing and development
type MockTokenStore struct {
tokens map[string]Token
mutex sync.RWMutex
projectIDs map[string]string // token -> projectID mapping for quick lookup
}
// NewMockTokenStore creates a new MockTokenStore
func NewMockTokenStore() *MockTokenStore {
return &MockTokenStore{
tokens: make(map[string]Token),
projectIDs: make(map[string]string),
}
}
// CreateToken creates a new token in the store
func (m *MockTokenStore) CreateToken(ctx context.Context, token Token) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.tokens[token.Token]; exists {
return ErrTokenExists
}
m.tokens[token.Token] = token
m.projectIDs[token.Token] = token.ProjectID
return nil
}
// GetTokenByID retrieves a token by ID
func (m *MockTokenStore) GetTokenByID(ctx context.Context, tokenID string) (Token, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
token, exists := m.tokens[tokenID]
if !exists {
return Token{}, ErrTokenNotFound
}
return token, nil
}
// UpdateToken updates a token in the store
func (m *MockTokenStore) UpdateToken(ctx context.Context, token Token) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.tokens[token.Token]; !exists {
return ErrTokenNotFound
}
m.tokens[token.Token] = token
m.projectIDs[token.Token] = token.ProjectID
return nil
}
// DeleteToken deletes a token from the store
func (m *MockTokenStore) DeleteToken(ctx context.Context, tokenID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.tokens[tokenID]; !exists {
return ErrTokenNotFound
}
delete(m.tokens, tokenID)
delete(m.projectIDs, tokenID)
return nil
}
// ListTokens retrieves all tokens from the store
func (m *MockTokenStore) ListTokens(ctx context.Context) ([]Token, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
tokens := make([]Token, 0, len(m.tokens))
for _, t := range m.tokens {
tokens = append(tokens, t)
}
return tokens, nil
}
// GetTokensByProjectID retrieves all tokens for a project
func (m *MockTokenStore) GetTokensByProjectID(ctx context.Context, projectID string) ([]Token, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
var tokens []Token
for _, t := range m.tokens {
if t.ProjectID == projectID {
tokens = append(tokens, t)
}
}
return tokens, nil
}
// IncrementTokenUsage increments the request count and updates the last_used_at timestamp
func (m *MockTokenStore) IncrementTokenUsage(ctx context.Context, tokenID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
t, exists := m.tokens[tokenID]
if !exists {
return ErrTokenNotFound
}
t.RequestCount++
now := time.Now()
t.LastUsedAt = &now
m.tokens[tokenID] = t
return nil
}
// CleanExpiredTokens deletes expired tokens from the store
func (m *MockTokenStore) CleanExpiredTokens(ctx context.Context) (int64, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
now := time.Now()
var count int64
for id, t := range m.tokens {
if t.ExpiresAt != nil && t.ExpiresAt.Before(now) {
delete(m.tokens, id)
delete(m.projectIDs, id)
count++
}
}
return count, nil
}
// CreateMockToken creates a new token in the mock store with the given parameters
func (m *MockTokenStore) CreateMockToken(tokenID, projectID string, expiresIn time.Duration, isActive bool, maxRequests *int) (Token, error) {
if tokenID == "" {
return Token{}, errors.New("token ID cannot be empty")
}
if projectID == "" {
return Token{}, errors.New("project ID cannot be empty")
}
var expiresAt *time.Time
if expiresIn > 0 {
expiry := time.Now().Add(expiresIn)
expiresAt = &expiry
}
now := time.Now()
token := Token{
Token: tokenID,
ProjectID: projectID,
ExpiresAt: expiresAt,
IsActive: isActive,
RequestCount: 0,
MaxRequests: maxRequests,
CreatedAt: now,
}
err := m.CreateToken(context.Background(), token)
return token, err
}
// TokenStoreAdapter adapts a MockTokenStore to the token.TokenStore interface
type TokenStoreAdapter struct {
store *MockTokenStore
}
// NewTokenStoreAdapter creates a new TokenStoreAdapter
func NewTokenStoreAdapter(store *MockTokenStore) *TokenStoreAdapter {
return &TokenStoreAdapter{
store: store,
}
}
// GetTokenByID retrieves a token by ID
func (a *TokenStoreAdapter) GetTokenByID(ctx context.Context, tokenID string) (token.TokenData, error) {
t, err := a.store.GetTokenByID(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return token.TokenData{}, token.ErrTokenNotFound
}
return token.TokenData{}, err
}
return ExportTokenData(t), nil
}
// IncrementTokenUsage increments the request count and updates the last_used_at timestamp
func (a *TokenStoreAdapter) IncrementTokenUsage(ctx context.Context, tokenID string) error {
err := a.store.IncrementTokenUsage(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return token.ErrTokenNotFound
}
return err
}
return nil
}
// CreateToken creates a new token in the store
func (a *TokenStoreAdapter) CreateToken(ctx context.Context, td token.TokenData) error {
dbToken := ImportTokenData(td)
return a.store.CreateToken(ctx, dbToken)
}
// ListTokens retrieves all tokens from the store
func (a *TokenStoreAdapter) ListTokens(ctx context.Context) ([]token.TokenData, error) {
dbTokens, err := a.store.ListTokens(ctx)
if err != nil {
return nil, err
}
tokens := make([]token.TokenData, len(dbTokens))
for i, t := range dbTokens {
tokens[i] = ExportTokenData(t)
}
return tokens, nil
}
// GetTokensByProjectID retrieves all tokens for a project
func (a *TokenStoreAdapter) GetTokensByProjectID(ctx context.Context, projectID string) ([]token.TokenData, error) {
dbTokens, err := a.store.GetTokensByProjectID(ctx, projectID)
if err != nil {
return nil, err
}
tokens := make([]token.TokenData, len(dbTokens))
for i, t := range dbTokens {
tokens[i] = ExportTokenData(t)
}
return tokens, nil
}
func TestMockTokenStore_EdgeCases(t *testing.T) {
store := NewMockTokenStore()
ctx := context.Background()
t.Run("GetTokenByID not found", func(t *testing.T) {
_, err := store.GetTokenByID(ctx, "notfound")
if err == nil {
t.Error("expected error for notfound token")
}
})
t.Run("IncrementTokenUsage not found", func(t *testing.T) {
err := store.IncrementTokenUsage(ctx, "notfound")
if err == nil {
t.Error("expected error for notfound token")
}
})
t.Run("ListTokens empty", func(t *testing.T) {
ts, err := store.ListTokens(ctx)
if err != nil {
t.Error(err)
}
if len(ts) != 0 {
t.Errorf("expected 0 tokens, got %d", len(ts))
}
})
t.Run("GetTokensByProjectID empty", func(t *testing.T) {
ts, err := store.GetTokensByProjectID(ctx, "pid")
if err != nil {
t.Error(err)
}
if len(ts) != 0 {
t.Errorf("expected 0 tokens, got %d", len(ts))
}
})
}
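// TestTokenStoreAdapter_RoundTrip is an illustrative sketch (not an original
// fixture of this repository) showing the adapter converting a stored database
// token into token.TokenData; the token and project IDs are hypothetical.
func TestTokenStoreAdapter_RoundTrip(t *testing.T) {
store := NewMockTokenStore()
adapter := NewTokenStoreAdapter(store)
if _, err := store.CreateMockToken("tok-123", "proj-1", time.Hour, true, nil); err != nil {
t.Fatalf("create: %v", err)
}
td, err := adapter.GetTokenByID(context.Background(), "tok-123")
if err != nil {
t.Fatalf("get: %v", err)
}
if td.ProjectID != "proj-1" {
t.Errorf("expected project proj-1, got %s", td.ProjectID)
}
}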
package database
import (
"time"
)
// Project represents a project in the database.
type Project struct {
ID string `json:"id"`
Name string `json:"name"`
OpenAIAPIKey string `json:"-"` // Sensitive data, not included in JSON
IsActive bool `json:"is_active"`
DeactivatedAt *time.Time `json:"deactivated_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// Token represents a token in the database.
type Token struct {
Token string `json:"token"`
ProjectID string `json:"project_id"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
IsActive bool `json:"is_active"`
DeactivatedAt *time.Time `json:"deactivated_at,omitempty"`
RequestCount int `json:"request_count"`
MaxRequests *int `json:"max_requests,omitempty"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
CacheHitCount int `json:"cache_hit_count"`
}
// AuditEvent represents an audit log entry in the database.
type AuditEvent struct {
ID string `json:"id"`
Timestamp time.Time `json:"timestamp"`
Action string `json:"action"`
Actor string `json:"actor"`
ProjectID *string `json:"project_id,omitempty"`
RequestID *string `json:"request_id,omitempty"`
CorrelationID *string `json:"correlation_id,omitempty"`
ClientIP *string `json:"client_ip,omitempty"`
Method *string `json:"method,omitempty"`
Path *string `json:"path,omitempty"`
UserAgent *string `json:"user_agent,omitempty"`
Outcome string `json:"outcome"`
Reason *string `json:"reason,omitempty"`
TokenID *string `json:"token_id,omitempty"`
Metadata *string `json:"metadata,omitempty"` // JSON string
}
// IsExpired returns true if the token has expired.
func (t *Token) IsExpired() bool {
if t.ExpiresAt == nil {
return false
}
return time.Now().After(*t.ExpiresAt)
}
// IsRateLimited returns true if the token has reached its maximum number of requests.
func (t *Token) IsRateLimited() bool {
if t.MaxRequests == nil {
return false
}
return t.RequestCount >= *t.MaxRequests
}
// IsValid returns true if the token is active, not expired, and not rate limited.
func (t *Token) IsValid() bool {
return t.IsActive && !t.IsExpired() && !t.IsRateLimited()
}
// IsDeactivated returns true if the token has been explicitly deactivated.
func (t *Token) IsDeactivated() bool {
return t.DeactivatedAt != nil
}
// IsDeactivated returns true if the project has been explicitly deactivated.
func (p *Project) IsDeactivated() bool {
return p.DeactivatedAt != nil
}
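// Illustrative use of the validity helpers (the values are hypothetical):
//
//	max := 100
//	tok := Token{IsActive: true, RequestCount: 100, MaxRequests: &max}
//	tok.IsRateLimited() // true: request_count has reached max_requests
//	tok.IsValid()       // false, even though the token is active and unexpired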
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/sofatutor/llm-proxy/internal/proxy"
)
var (
// ErrProjectNotFound is returned when a project is not found.
ErrProjectNotFound = errors.New("project not found")
// ErrProjectExists is returned when a project already exists.
ErrProjectExists = errors.New("project already exists")
)
// GetProjectByName retrieves a project by name.
func (d *DB) GetProjectByName(ctx context.Context, name string) (Project, error) {
query := `
SELECT id, name, openai_api_key, is_active, deactivated_at, created_at, updated_at
FROM projects
WHERE name = ?
`
var project Project
var deactivatedAt sql.NullTime
err := d.QueryRowContextRebound(ctx, query, name).Scan(
&project.ID,
&project.Name,
&project.OpenAIAPIKey,
&project.IsActive,
&deactivatedAt,
&project.CreatedAt,
&project.UpdatedAt,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return Project{}, ErrProjectNotFound
}
return Project{}, fmt.Errorf("failed to get project: %w", err)
}
if deactivatedAt.Valid {
project.DeactivatedAt = &deactivatedAt.Time
}
return project, nil
}
// ToProxyProject converts a database.Project to a proxy.Project
func ToProxyProject(dbProject Project) proxy.Project {
return proxy.Project{
ID: dbProject.ID,
Name: dbProject.Name,
OpenAIAPIKey: dbProject.OpenAIAPIKey,
IsActive: dbProject.IsActive,
DeactivatedAt: dbProject.DeactivatedAt,
CreatedAt: dbProject.CreatedAt,
UpdatedAt: dbProject.UpdatedAt,
}
}
// ToDBProject converts a proxy.Project to a database.Project
func ToDBProject(proxyProject proxy.Project) Project {
return Project{
ID: proxyProject.ID,
Name: proxyProject.Name,
OpenAIAPIKey: proxyProject.OpenAIAPIKey,
IsActive: proxyProject.IsActive,
DeactivatedAt: proxyProject.DeactivatedAt,
CreatedAt: proxyProject.CreatedAt,
UpdatedAt: proxyProject.UpdatedAt,
}
}
// DBListProjects retrieves all projects from the database, ordered by name.
func (d *DB) DBListProjects(ctx context.Context) ([]Project, error) {
query := `
SELECT id, name, openai_api_key, is_active, deactivated_at, created_at, updated_at
FROM projects
ORDER BY name ASC
`
rows, err := d.QueryContextRebound(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list projects: %w", err)
}
defer func() {
_ = rows.Close()
}()
var projects []Project
for rows.Next() {
var project Project
var deactivatedAt sql.NullTime
if err := rows.Scan(
&project.ID,
&project.Name,
&project.OpenAIAPIKey,
&project.IsActive,
&deactivatedAt,
&project.CreatedAt,
&project.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("failed to scan project: %w", err)
}
if deactivatedAt.Valid {
project.DeactivatedAt = &deactivatedAt.Time
}
projects = append(projects, project)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating projects: %w", err)
}
return projects, nil
}
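// DBCreateProject inserts a new project into the database.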
func (d *DB) DBCreateProject(ctx context.Context, project Project) error {
query := `
INSERT INTO projects (id, name, openai_api_key, is_active, deactivated_at, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
`
_, err := d.ExecContextRebound(
ctx,
query,
project.ID,
project.Name,
project.OpenAIAPIKey,
project.IsActive,
project.DeactivatedAt,
project.CreatedAt,
project.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create project: %w", err)
}
return nil
}
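// DBGetProjectByID retrieves a project by ID from the database.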
func (d *DB) DBGetProjectByID(ctx context.Context, projectID string) (Project, error) {
query := `
SELECT id, name, openai_api_key, is_active, deactivated_at, created_at, updated_at
FROM projects
WHERE id = ?
`
var project Project
var deactivatedAt sql.NullTime
err := d.QueryRowContextRebound(ctx, query, projectID).Scan(
&project.ID,
&project.Name,
&project.OpenAIAPIKey,
&project.IsActive,
&deactivatedAt,
&project.CreatedAt,
&project.UpdatedAt,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return Project{}, ErrProjectNotFound
}
return Project{}, fmt.Errorf("failed to get project: %w", err)
}
if deactivatedAt.Valid {
project.DeactivatedAt = &deactivatedAt.Time
}
return project, nil
}
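// DBUpdateProject updates an existing project and refreshes its updated_at timestamp.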
func (d *DB) DBUpdateProject(ctx context.Context, project Project) error {
project.UpdatedAt = time.Now()
query := `
UPDATE projects
SET name = ?, openai_api_key = ?, is_active = ?, deactivated_at = ?, updated_at = ?
WHERE id = ?
`
result, err := d.ExecContextRebound(
ctx,
query,
project.Name,
project.OpenAIAPIKey,
project.IsActive,
project.DeactivatedAt,
project.UpdatedAt,
project.ID,
)
if err != nil {
return fmt.Errorf("failed to update project: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return ErrProjectNotFound
}
return nil
}
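// DBDeleteProject removes a project from the database.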
func (d *DB) DBDeleteProject(ctx context.Context, projectID string) error {
query := `
DELETE FROM projects
WHERE id = ?
`
result, err := d.ExecContextRebound(ctx, query, projectID)
if err != nil {
return fmt.Errorf("failed to delete project: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return ErrProjectNotFound
}
return nil
}
// --- proxy.ProjectStore interface adapters ---
func (d *DB) ListProjects(ctx context.Context) ([]proxy.Project, error) {
dbProjects, err := d.DBListProjects(ctx)
if err != nil {
return nil, err
}
var out []proxy.Project
for _, p := range dbProjects {
out = append(out, ToProxyProject(p))
}
return out, nil
}
func (d *DB) CreateProject(ctx context.Context, p proxy.Project) error {
return d.DBCreateProject(ctx, ToDBProject(p))
}
func (d *DB) GetProjectByID(ctx context.Context, id string) (proxy.Project, error) {
dbP, err := d.DBGetProjectByID(ctx, id)
if err != nil {
return proxy.Project{}, err
}
return ToProxyProject(dbP), nil
}
func (d *DB) UpdateProject(ctx context.Context, p proxy.Project) error {
return d.DBUpdateProject(ctx, ToDBProject(p))
}
func (d *DB) DeleteProject(ctx context.Context, id string) error {
return d.DBDeleteProject(ctx, id)
}
// GetAPIKeyForProject retrieves the OpenAI API key for a project by ID
func (d *DB) GetAPIKeyForProject(ctx context.Context, projectID string) (string, error) {
query := `SELECT openai_api_key FROM projects WHERE id = ?`
var apiKey string
err := d.QueryRowContextRebound(ctx, query, projectID).Scan(&apiKey)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", ErrProjectNotFound
}
return "", fmt.Errorf("failed to get API key for project: %w", err)
}
return apiKey, nil
}
// GetProjectActive retrieves the active status for a project by ID
func (d *DB) GetProjectActive(ctx context.Context, projectID string) (bool, error) {
query := `SELECT is_active FROM projects WHERE id = ?`
var isActive bool
err := d.QueryRowContextRebound(ctx, query, projectID).Scan(&isActive)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return false, ErrProjectNotFound
}
return false, fmt.Errorf("failed to get project active status: %w", err)
}
return isActive, nil
}
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"github.com/sofatutor/llm-proxy/internal/token"
)
var (
// ErrTokenNotFound is returned when a token is not found.
ErrTokenNotFound = errors.New("token not found")
// ErrTokenExists is returned when a token already exists.
ErrTokenExists = errors.New("token already exists")
)
// CreateToken creates a new token in the database.
func (d *DB) CreateToken(ctx context.Context, token Token) error {
query := `
INSERT INTO tokens (token, project_id, expires_at, is_active, deactivated_at, request_count, max_requests, created_at, last_used_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := d.ExecContextRebound(
ctx,
query,
token.Token,
token.ProjectID,
token.ExpiresAt,
token.IsActive,
nil,
token.RequestCount,
token.MaxRequests,
token.CreatedAt,
token.LastUsedAt,
)
if err != nil {
return fmt.Errorf("failed to create token: %w", err)
}
return nil
}
// GetTokenByID retrieves a token by ID.
func (d *DB) GetTokenByID(ctx context.Context, tokenID string) (Token, error) {
query := `
SELECT token, project_id, expires_at, is_active, deactivated_at, request_count, max_requests, created_at, last_used_at, cache_hit_count
FROM tokens
WHERE token = ?
`
var token Token
var expiresAt, lastUsedAt, deactivatedAt sql.NullTime
var maxRequests sql.NullInt32
err := d.QueryRowContextRebound(ctx, query, tokenID).Scan(
&token.Token,
&token.ProjectID,
&expiresAt,
&token.IsActive,
&deactivatedAt,
&token.RequestCount,
&maxRequests,
&token.CreatedAt,
&lastUsedAt,
&token.CacheHitCount,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return Token{}, ErrTokenNotFound
}
return Token{}, fmt.Errorf("failed to get token: %w", err)
}
if expiresAt.Valid {
token.ExpiresAt = &expiresAt.Time
}
if lastUsedAt.Valid {
token.LastUsedAt = &lastUsedAt.Time
}
if deactivatedAt.Valid {
token.DeactivatedAt = &deactivatedAt.Time
}
if maxRequests.Valid {
maxReq := int(maxRequests.Int32)
token.MaxRequests = &maxReq
}
return token, nil
}
// UpdateToken updates a token in the database.
func (d *DB) UpdateToken(ctx context.Context, token Token) error {
query := `
UPDATE tokens
SET project_id = ?, expires_at = ?, is_active = ?, request_count = ?, max_requests = ?, last_used_at = ?
WHERE token = ?
`
result, err := d.ExecContextRebound(
ctx,
query,
token.ProjectID,
token.ExpiresAt,
token.IsActive,
token.RequestCount,
token.MaxRequests,
token.LastUsedAt,
token.Token,
)
if err != nil {
return fmt.Errorf("failed to update token: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return ErrTokenNotFound
}
return nil
}
// DeleteToken deletes a token from the database.
func (d *DB) DeleteToken(ctx context.Context, tokenID string) error {
query := `
DELETE FROM tokens
WHERE token = ?
`
result, err := d.ExecContextRebound(ctx, query, tokenID)
if err != nil {
return fmt.Errorf("failed to delete token: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return ErrTokenNotFound
}
return nil
}
// ListTokens retrieves all tokens from the database.
func (d *DB) ListTokens(ctx context.Context) ([]Token, error) {
query := `
SELECT token, project_id, expires_at, is_active, deactivated_at, request_count, max_requests, created_at, last_used_at, cache_hit_count
FROM tokens
ORDER BY created_at DESC
`
return d.queryTokens(ctx, query)
}
// GetTokensByProjectID retrieves all tokens for a project.
func (d *DB) GetTokensByProjectID(ctx context.Context, projectID string) ([]Token, error) {
query := `
SELECT token, project_id, expires_at, is_active, deactivated_at, request_count, max_requests, created_at, last_used_at, cache_hit_count
FROM tokens
WHERE project_id = ?
ORDER BY created_at DESC
`
return d.queryTokens(ctx, query, projectID)
}
// IncrementTokenUsage increments the request count and updates the last_used_at timestamp.
func (d *DB) IncrementTokenUsage(ctx context.Context, tokenID string) error {
now := time.Now()
query := `
UPDATE tokens
SET request_count = request_count + 1, last_used_at = ?
WHERE token = ?
`
result, err := d.ExecContextRebound(ctx, query, now, tokenID)
if err != nil {
return fmt.Errorf("failed to increment token usage: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return ErrTokenNotFound
}
return nil
}
// CleanExpiredTokens deletes expired tokens from the database.
func (d *DB) CleanExpiredTokens(ctx context.Context) (int64, error) {
now := time.Now()
query := `
DELETE FROM tokens
WHERE expires_at IS NOT NULL AND expires_at < ?
`
result, err := d.ExecContextRebound(ctx, query, now)
if err != nil {
return 0, fmt.Errorf("failed to clean expired tokens: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
return rowsAffected, nil
}
// queryTokens is a helper function to query tokens.
func (d *DB) queryTokens(ctx context.Context, query string, args ...interface{}) ([]Token, error) {
rows, err := d.QueryContextRebound(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to query tokens: %w", err)
}
defer func() {
_ = rows.Close()
}()
var tokens []Token
for rows.Next() {
var token Token
var expiresAt, lastUsedAt, deactivatedAt sql.NullTime
var maxRequests sql.NullInt32
if err := rows.Scan(
&token.Token,
&token.ProjectID,
&expiresAt,
&token.IsActive,
&deactivatedAt,
&token.RequestCount,
&maxRequests,
&token.CreatedAt,
&lastUsedAt,
&token.CacheHitCount,
); err != nil {
return nil, fmt.Errorf("failed to scan token: %w", err)
}
if expiresAt.Valid {
token.ExpiresAt = &expiresAt.Time
}
if lastUsedAt.Valid {
token.LastUsedAt = &lastUsedAt.Time
}
if deactivatedAt.Valid {
token.DeactivatedAt = &deactivatedAt.Time
}
if maxRequests.Valid {
maxReq := int(maxRequests.Int32)
token.MaxRequests = &maxReq
}
tokens = append(tokens, token)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating tokens: %w", err)
}
return tokens, nil
}
// --- token.TokenStore interface adapter for *DB ---
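// DBTokenStoreAdapter adapts a *DB to the token.TokenStore interface.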
type DBTokenStoreAdapter struct {
db *DB
}
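// NewDBTokenStoreAdapter creates a new DBTokenStoreAdapter backed by db.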
func NewDBTokenStoreAdapter(db *DB) *DBTokenStoreAdapter {
return &DBTokenStoreAdapter{db: db}
}
func (a *DBTokenStoreAdapter) GetTokenByID(ctx context.Context, tokenID string) (token.TokenData, error) {
dbToken, err := a.db.GetTokenByID(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return token.TokenData{}, token.ErrTokenNotFound
}
return token.TokenData{}, err
}
return ExportTokenData(dbToken), nil
}
func (a *DBTokenStoreAdapter) IncrementTokenUsage(ctx context.Context, tokenID string) error {
return a.db.IncrementTokenUsage(ctx, tokenID)
}
func (a *DBTokenStoreAdapter) CreateToken(ctx context.Context, td token.TokenData) error {
dbToken := ImportTokenData(td)
return a.db.CreateToken(ctx, dbToken)
}
func (a *DBTokenStoreAdapter) UpdateToken(ctx context.Context, td token.TokenData) error {
dbToken := ImportTokenData(td)
return a.db.UpdateToken(ctx, dbToken)
}
func (a *DBTokenStoreAdapter) ListTokens(ctx context.Context) ([]token.TokenData, error) {
dbTokens, err := a.db.ListTokens(ctx)
if err != nil {
return nil, err
}
tokens := make([]token.TokenData, len(dbTokens))
for i, t := range dbTokens {
tokens[i] = ExportTokenData(t)
}
return tokens, nil
}
func (a *DBTokenStoreAdapter) GetTokensByProjectID(ctx context.Context, projectID string) ([]token.TokenData, error) {
dbTokens, err := a.db.GetTokensByProjectID(ctx, projectID)
if err != nil {
return nil, err
}
tokens := make([]token.TokenData, len(dbTokens))
for i, t := range dbTokens {
tokens[i] = ExportTokenData(t)
}
return tokens, nil
}
// ImportTokenData converts token.TokenData into a database Token.
func ImportTokenData(td token.TokenData) Token {
return Token{
Token: td.Token,
ProjectID: td.ProjectID,
ExpiresAt: td.ExpiresAt,
IsActive: td.IsActive,
DeactivatedAt: td.DeactivatedAt,
RequestCount: td.RequestCount,
MaxRequests: td.MaxRequests,
CreatedAt: td.CreatedAt,
LastUsedAt: td.LastUsedAt,
CacheHitCount: td.CacheHitCount,
}
}
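// ExportTokenData converts a database Token into token.TokenData.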
func ExportTokenData(t Token) token.TokenData {
return token.TokenData{
Token: t.Token,
ProjectID: t.ProjectID,
ExpiresAt: t.ExpiresAt,
IsActive: t.IsActive,
DeactivatedAt: t.DeactivatedAt,
RequestCount: t.RequestCount,
MaxRequests: t.MaxRequests,
CreatedAt: t.CreatedAt,
LastUsedAt: t.LastUsedAt,
CacheHitCount: t.CacheHitCount,
}
}
// --- RevocationStore interface implementation ---
// RevokeToken disables a token by setting is_active to false and deactivated_at to current time
func (a *DBTokenStoreAdapter) RevokeToken(ctx context.Context, tokenID string) error {
if tokenID == "" {
return token.ErrTokenNotFound
}
now := time.Now()
query := `UPDATE tokens SET is_active = ?, deactivated_at = COALESCE(deactivated_at, ?) WHERE token = ? AND is_active = ?`
result, err := a.db.ExecContextRebound(ctx, query, false, now, tokenID, true)
if err != nil {
return fmt.Errorf("failed to revoke token: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to check rows affected: %w", err)
}
if rowsAffected == 0 {
// Check if token exists at all (could be already inactive)
var exists bool
checkQuery := `SELECT 1 FROM tokens WHERE token = ? LIMIT 1`
err = a.db.QueryRowContextRebound(ctx, checkQuery, tokenID).Scan(&exists)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return token.ErrTokenNotFound
}
return fmt.Errorf("failed to check token existence: %w", err)
}
// Token exists but was already inactive - this is idempotent, no error
}
return nil
}
// DeleteToken completely removes a token from storage
func (a *DBTokenStoreAdapter) DeleteToken(ctx context.Context, tokenID string) error {
if tokenID == "" {
return token.ErrTokenNotFound
}
result, err := a.db.ExecContextRebound(ctx, "DELETE FROM tokens WHERE token = ?", tokenID)
if err != nil {
return fmt.Errorf("failed to delete token: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return token.ErrTokenNotFound
}
return nil
}
// RevokeBatchTokens revokes multiple tokens at once
func (a *DBTokenStoreAdapter) RevokeBatchTokens(ctx context.Context, tokenIDs []string) (int, error) {
if len(tokenIDs) == 0 {
return 0, nil
}
now := time.Now()
placeholders := make([]string, len(tokenIDs))
args := make([]interface{}, len(tokenIDs)+3)
args[0] = false
args[1] = now
for i, tokenID := range tokenIDs {
placeholders[i] = "?"
args[i+2] = tokenID
}
// Append active-state filter parameter
args[len(args)-1] = true
query := fmt.Sprintf(`UPDATE tokens SET is_active = ?, deactivated_at = COALESCE(deactivated_at, ?) WHERE token IN (%s) AND is_active = ?`, strings.Join(placeholders, ","))
result, err := a.db.ExecContextRebound(ctx, query, args...)
if err != nil {
return 0, fmt.Errorf("failed to revoke batch tokens: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
return int(rowsAffected), nil
}
// RevokeProjectTokens revokes all tokens for a project
func (a *DBTokenStoreAdapter) RevokeProjectTokens(ctx context.Context, projectID string) (int, error) {
if projectID == "" {
return 0, nil
}
now := time.Now()
query := `UPDATE tokens SET is_active = ?, deactivated_at = COALESCE(deactivated_at, ?) WHERE project_id = ? AND is_active = ?`
result, err := a.db.ExecContextRebound(ctx, query, false, now, projectID, true)
if err != nil {
return 0, fmt.Errorf("failed to revoke project tokens: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
return int(rowsAffected), nil
}
// RevokeExpiredTokens revokes all tokens that have expired
func (a *DBTokenStoreAdapter) RevokeExpiredTokens(ctx context.Context) (int, error) {
now := time.Now()
query := `UPDATE tokens SET is_active = ?, deactivated_at = COALESCE(deactivated_at, ?) WHERE expires_at IS NOT NULL AND expires_at < ? AND is_active = ?`
result, err := a.db.ExecContextRebound(ctx, query, false, now, now, true)
if err != nil {
return 0, fmt.Errorf("failed to revoke expired tokens: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
return int(rowsAffected), nil
}
// IncrementCacheHitCount increments the cache_hit_count for a single token.
func (d *DB) IncrementCacheHitCount(ctx context.Context, tokenID string, delta int) error {
if delta <= 0 {
return nil
}
query := `UPDATE tokens SET cache_hit_count = cache_hit_count + ? WHERE token = ?`
_, err := d.ExecContextRebound(ctx, query, delta, tokenID)
if err != nil {
return fmt.Errorf("failed to increment cache hit count: %w", err)
}
return nil
}
// IncrementCacheHitCountBatch increments cache_hit_count for multiple tokens in batch.
// The deltas map has token IDs as keys and increment values as values.
func (d *DB) IncrementCacheHitCountBatch(ctx context.Context, deltas map[string]int) error {
if len(deltas) == 0 {
return nil
}
// Use a transaction for batch updates
tx, err := d.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
_ = tx.Rollback() // No-op if already committed
}()
query := `UPDATE tokens SET cache_hit_count = cache_hit_count + ? WHERE token = ?`
stmt, err := tx.PrepareContext(ctx, d.RebindQuery(query))
if err != nil {
return fmt.Errorf("failed to prepare statement: %w", err)
}
defer func() {
_ = stmt.Close()
}()
for tokenID, delta := range deltas {
if delta <= 0 {
continue
}
if _, err := stmt.ExecContext(ctx, delta, tokenID); err != nil {
return fmt.Errorf("failed to increment cache hit count for token %s: %w", tokenID, err)
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
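// Usage sketch (illustrative; the token IDs and deltas are hypothetical):
//
//	deltas := map[string]int{"tok-a": 3, "tok-b": 1}
//	if err := db.IncrementCacheHitCountBatch(ctx, deltas); err != nil {
//		log.Printf("cache hit flush failed: %v", err)
//	}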
package database
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
)
// Placeholder returns the appropriate placeholder for the driver.
// For SQLite: ?, for PostgreSQL: $1, $2, etc.
func (d *DB) Placeholder(n int) string {
if d.driver == DriverPostgres {
return fmt.Sprintf("$%d", n)
}
return "?"
}
// Placeholders returns a slice of placeholders for the driver.
// For n=3: SQLite returns ["?", "?", "?"], PostgreSQL returns ["$1", "$2", "$3"].
func (d *DB) Placeholders(n int) []string {
result := make([]string, n)
for i := 0; i < n; i++ {
result[i] = d.Placeholder(i + 1)
}
return result
}
// PlaceholderList returns a comma-separated list of placeholders.
// For n=3: SQLite returns "?, ?, ?", PostgreSQL returns "$1, $2, $3".
func (d *DB) PlaceholderList(n int) string {
return strings.Join(d.Placeholders(n), ", ")
}
// RebindQuery converts a query from ? placeholders to the appropriate
// placeholder style for the database driver.
//
// IMPORTANT: This function performs a simple character replacement and does NOT
// handle ? characters inside SQL string literals (e.g., "WHERE name = 'what?'").
// Since this codebase exclusively uses parameterized queries with ? as placeholders,
// this limitation does not affect normal usage. If you need to use literal ? in
// string values, use parameterized queries: "WHERE name = ?" with the value passed
// as an argument.
func (d *DB) RebindQuery(query string) string {
if d.driver != DriverPostgres {
return query
}
// Convert ? to $1, $2, $3, etc. (single pass for better performance)
var builder strings.Builder
builder.Grow(len(query) + 10) // pre-allocate with some buffer
count := 0
for i := 0; i < len(query); i++ {
if query[i] == '?' {
count++
builder.WriteString(fmt.Sprintf("$%d", count))
} else {
builder.WriteByte(query[i])
}
}
return builder.String()
}
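// For example, with a PostgreSQL-backed DB (illustrative):
//
//	d.RebindQuery("SELECT 1 FROM tokens WHERE token = ? AND is_active = ?")
//	// => "SELECT 1 FROM tokens WHERE token = $1 AND is_active = $2"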
// ExecContextRebound executes a query with automatic placeholder rebinding.
func (d *DB) ExecContextRebound(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return d.db.ExecContext(ctx, d.RebindQuery(query), args...)
}
// QueryRowContextRebound queries a single row with automatic placeholder rebinding.
func (d *DB) QueryRowContextRebound(ctx context.Context, query string, args ...interface{}) *sql.Row {
return d.db.QueryRowContext(ctx, d.RebindQuery(query), args...)
}
// QueryContextRebound queries multiple rows with automatic placeholder rebinding.
func (d *DB) QueryContextRebound(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
return d.db.QueryContext(ctx, d.RebindQuery(query), args...)
}
// BackupDatabase creates a backup of the database.
// Note: This function is SQLite-specific. For PostgreSQL, use pg_dump.
func (d *DB) BackupDatabase(ctx context.Context, backupPath string) error {
if d.driver == DriverPostgres {
return fmt.Errorf("backup not supported for PostgreSQL via this method; use pg_dump")
}
// Validate the backupPath to ensure it is a valid file path
if backupPath == "" {
return fmt.Errorf("backup path cannot be empty")
}
// SQLite does not support parameterized VACUUM INTO, so we must sanitize the path.
// Reject single quotes (which would break out of the quoted literal below) and
// other SQL metacharacters anywhere in the path, plus overly long or option-like paths.
if len(backupPath) > 256 || backupPath[0] == '-' || strings.ContainsAny(backupPath, "';|") {
return fmt.Errorf("invalid backup path")
}
// For SQLite, we can use the VACUUM INTO statement to create a backup
query := fmt.Sprintf("VACUUM INTO '%s'", backupPath)
_, err := d.db.ExecContext(ctx, query)
if err != nil {
return fmt.Errorf("failed to backup database: %w", err)
}
return nil
}
// MaintainDatabase performs regular maintenance on the database.
// WARNING: VACUUM and ANALYZE can be expensive operations. In production, schedule this function to run periodically (e.g., daily) rather than on every call.
// The caller is responsible for scheduling.
func (d *DB) MaintainDatabase(ctx context.Context) error {
if d.driver == DriverPostgres {
// PostgreSQL uses VACUUM ANALYZE
_, err := d.db.ExecContext(ctx, "VACUUM ANALYZE")
if err != nil {
return fmt.Errorf("failed to vacuum analyze database: %w", err)
}
return nil
}
// SQLite-specific maintenance
// Run VACUUM to reclaim space and optimize the database
_, err := d.db.ExecContext(ctx, "VACUUM")
if err != nil {
return fmt.Errorf("failed to vacuum database: %w", err)
}
// Run PRAGMA optimize to optimize the database
_, err = d.db.ExecContext(ctx, "PRAGMA optimize")
if err != nil {
return fmt.Errorf("failed to optimize database: %w", err)
}
// Run ANALYZE to update statistics
_, err = d.db.ExecContext(ctx, "ANALYZE")
if err != nil {
return fmt.Errorf("failed to analyze database: %w", err)
}
return nil
}
// boolValue returns the appropriate boolean representation for the driver.
// SQLite uses 1/0, PostgreSQL uses true/false.
func (d *DB) boolValue(b bool) interface{} {
if d.driver == DriverPostgres {
return b
}
if b {
return 1
}
return 0
}
// GetStats returns database statistics: database_size_bytes, project_count,
// active_token_count, expired_token_count, and total_request_count.
func (d *DB) GetStats(ctx context.Context) (map[string]interface{}, error) {
stats := make(map[string]interface{})
// Get database size (driver-specific - these queries are fundamentally different)
var dbSize int64
if d.driver == DriverPostgres {
err := d.db.QueryRowContext(ctx, "SELECT pg_database_size(current_database())").Scan(&dbSize)
if err != nil {
return nil, fmt.Errorf("failed to get database size: %w", err)
}
} else {
err := d.db.QueryRowContext(ctx, "SELECT (SELECT page_count FROM pragma_page_count) * (SELECT page_size FROM pragma_page_size)").Scan(&dbSize)
if err != nil {
return nil, fmt.Errorf("failed to get database size: %w", err)
}
}
stats["database_size_bytes"] = dbSize
// Count active projects
var projectCount int
err := d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM projects").Scan(&projectCount)
if err != nil {
return nil, fmt.Errorf("failed to count projects: %w", err)
}
stats["project_count"] = projectCount
// Count active tokens using QueryRowContextRebound for placeholder rebinding
var activeTokens int
err = d.QueryRowContextRebound(ctx,
"SELECT COUNT(*) FROM tokens WHERE is_active = ? AND (expires_at IS NULL OR expires_at > ?)",
d.boolValue(true), time.Now()).Scan(&activeTokens)
if err != nil {
return nil, fmt.Errorf("failed to count active tokens: %w", err)
}
stats["active_token_count"] = activeTokens
// Count expired tokens using QueryRowContextRebound for placeholder rebinding
var expiredTokens int
err = d.QueryRowContextRebound(ctx,
"SELECT COUNT(*) FROM tokens WHERE expires_at IS NOT NULL AND expires_at <= ?",
time.Now()).Scan(&expiredTokens)
if err != nil {
return nil, fmt.Errorf("failed to count expired tokens: %w", err)
}
stats["expired_token_count"] = expiredTokens
// Count total request count
var totalRequests sql.NullInt64
err = d.db.QueryRowContext(ctx, "SELECT SUM(request_count) FROM tokens").Scan(&totalRequests)
if err != nil {
return nil, fmt.Errorf("failed to sum request counts: %w", err)
}
if totalRequests.Valid {
stats["total_request_count"] = totalRequests.Int64
} else {
stats["total_request_count"] = int64(0)
}
return stats, nil
}
// IsTokenValid checks if a token is valid (exists, is active, not expired, and not rate limited).
func (d *DB) IsTokenValid(ctx context.Context, tokenID string) (bool, error) {
token, err := d.GetTokenByID(ctx, tokenID)
if err != nil {
if err == ErrTokenNotFound {
return false, nil
}
return false, err
}
return token.IsValid(), nil
}
package dispatcher
// PermanentBackendError is a custom error type for permanent backend errors (e.g., Helicone 500s that should not be retried)
type PermanentBackendError struct {
Msg string
}
func (e *PermanentBackendError) Error() string {
return e.Msg
}
package plugins
import (
"context"
"encoding/json"
"fmt"
"os"
"github.com/sofatutor/llm-proxy/internal/dispatcher"
)
// FilePlugin implements file-based event logging
type FilePlugin struct {
filePath string
file *os.File
}
// NewFilePlugin creates a new file plugin
func NewFilePlugin() *FilePlugin {
return &FilePlugin{}
}
// Init initializes the file plugin with configuration
func (p *FilePlugin) Init(cfg map[string]string) error {
filePath, ok := cfg["endpoint"]
if !ok || filePath == "" {
return fmt.Errorf("file plugin requires 'endpoint' configuration (file path)")
}
p.filePath = filePath
// Open file for writing
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
return fmt.Errorf("failed to open file %s: %w", filePath, err)
}
p.file = file
return nil
}
// SendEvents writes events to the file as JSONL (JSON Lines)
func (p *FilePlugin) SendEvents(ctx context.Context, events []dispatcher.EventPayload) error {
if p.file == nil {
return fmt.Errorf("file plugin not initialized")
}
for _, event := range events {
line, err := json.Marshal(event)
if err != nil {
return fmt.Errorf("failed to marshal event: %w", err)
}
// Write JSON line with newline
if _, err := p.file.Write(append(line, '\n')); err != nil {
return fmt.Errorf("failed to write to file: %w", err)
}
}
// Ensure data is written to disk
if err := p.file.Sync(); err != nil {
return fmt.Errorf("failed to sync file: %w", err)
}
return nil
}
// Close closes the file
func (p *FilePlugin) Close() error {
if p.file != nil {
return p.file.Close()
}
return nil
}
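// Usage sketch (illustrative; the file path is an assumption):
//
//	p := NewFilePlugin()
//	if err := p.Init(map[string]string{"endpoint": "/var/log/llm-proxy/events.jsonl"}); err != nil {
//		log.Fatal(err)
//	}
//	defer func() { _ = p.Close() }()
//	if err := p.SendEvents(ctx, events); err != nil {
//		log.Printf("send failed: %v", err)
//	}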
package plugins
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"github.com/sofatutor/llm-proxy/internal/dispatcher"
)
// HeliconePlugin implements Helicone backend integration
type HeliconePlugin struct {
apiKey string
endpoint string
client *http.Client
}
// NewHeliconePlugin creates a new Helicone plugin
func NewHeliconePlugin() *HeliconePlugin {
return &HeliconePlugin{
client: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// Init initializes the Helicone plugin with configuration
func (p *HeliconePlugin) Init(cfg map[string]string) error {
apiKey, ok := cfg["api-key"]
if !ok || apiKey == "" {
return fmt.Errorf("helicone plugin requires 'api-key' configuration")
}
endpoint, ok := cfg["endpoint"]
if !ok || endpoint == "" {
endpoint = "https://api.worker.helicone.ai/custom/v1/log"
}
p.apiKey = apiKey
p.endpoint = endpoint
return nil
}
// SendEvents sends events to Helicone
func (p *HeliconePlugin) SendEvents(ctx context.Context, events []dispatcher.EventPayload) error {
if len(events) == 0 {
return nil
}
for _, event := range events {
// Skip events with empty output
if len(event.Output) == 0 && event.OutputBase64 == "" {
log.Printf("[helicone] Skipping event with empty output: RunID=%s, Path=%v", event.RunID, event.Metadata["path"])
continue
}
payload, err := heliconePayloadFromEvent(event)
if err != nil {
return err
}
if err := p.sendHeliconeEvent(ctx, payload); err != nil {
// Print payload for debugging on error
log.Printf("[helicone] Error sending event. Payload: %s", mustMarshalJSON(payload))
return err
}
}
return nil
}
// heliconePayloadFromEvent maps EventPayload to Helicone manual logger format
func heliconePayloadFromEvent(event dispatcher.EventPayload) (map[string]interface{}, error) {
// Extract request and response bodies
var reqBody map[string]interface{}
var respBody map[string]interface{}
isJSON := false
if len(event.Input) > 0 {
_ = json.Unmarshal(event.Input, &reqBody)
}
if reqBody == nil {
reqBody = map[string]interface{}{}
}
if len(event.Output) > 0 {
if err := json.Unmarshal(event.Output, &respBody); err == nil {
isJSON = true
}
}
// Timing (use event.Timestamp for both if no better info)
timestamp := event.Timestamp
sec := timestamp.Unix()
ms := timestamp.Nanosecond() / 1e6
timing := map[string]interface{}{
"startTime": map[string]int64{"seconds": sec, "milliseconds": int64(ms)},
"endTime": map[string]int64{"seconds": sec, "milliseconds": int64(ms)},
}
// Meta
meta := map[string]string{}
if event.UserID != nil {
meta["Helicone-User-Id"] = *event.UserID
}
if event.Extra != nil {
for k, v := range event.Extra {
if s, ok := v.(string); ok {
meta[k] = s
}
}
}
if event.Metadata != nil {
for k, v := range event.Metadata {
if s, ok := v.(string); ok {
meta[k] = s
}
}
// Ensure request_id is propagated if present in metadata
if rid, ok := event.Metadata["request_id"].(string); ok && rid != "" {
meta["request_id"] = rid
}
}
// Help Helicone infer pricing/provider when using manual logger
if _, ok := meta["provider"]; !ok {
meta["provider"] = "openai"
}
// ProviderRequest
// Prefer actual API path from metadata if present so Helicone can infer provider
urlPath := "custom-model-nopath"
if event.Metadata != nil {
if pth, ok := event.Metadata["path"].(string); ok && pth != "" {
urlPath = pth
}
}
providerRequest := map[string]interface{}{
"url": urlPath,
"json": reqBody,
"meta": meta,
}
// ProviderResponse
status := 200
if s, ok := event.Metadata["status"].(int); ok {
status = s
}
providerResponse := map[string]interface{}{
"status": status,
"headers": map[string]string{}, // Optionally fill from event.Metadata
}
// Ensure Helicone sees the provider as non-CUSTOM when using manual logger
if _, ok := meta["Helicone-Provider"]; !ok {
// Mirror meta["provider"] if set, otherwise default to openai
prov := meta["provider"]
if prov == "" {
prov = "openai"
}
meta["Helicone-Provider"] = prov
}
// Helicone requires providerResponse.json to be present; ensure it's always an object
if isJSON && respBody != nil {
providerResponse["json"] = respBody
} else {
providerResponse["json"] = map[string]interface{}{}
providerResponse["note"] = "response was not JSON; omitted"
}
// Inject usage if we computed it so Helicone can compute cost
if event.TokensUsage != nil {
if m, ok := providerResponse["json"].(map[string]interface{}); ok {
m["usage"] = map[string]int{
"prompt_tokens": event.TokensUsage.Prompt,
"completion_tokens": event.TokensUsage.Completion,
"total_tokens": event.TokensUsage.Prompt + event.TokensUsage.Completion,
}
}
}
// If OutputBase64 is set, include as base64
if event.OutputBase64 != "" {
providerResponse["base64"] = event.OutputBase64
}
return map[string]interface{}{
"providerRequest": providerRequest,
"providerResponse": providerResponse,
"timing": timing,
}, nil
}
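// The resulting payload has roughly this shape (illustrative values):
//
//	{
//	  "providerRequest":  {"url": "/v1/chat/completions", "json": {...}, "meta": {"provider": "openai", ...}},
//	  "providerResponse": {"status": 200, "headers": {}, "json": {...}},
//	  "timing":           {"startTime": {"seconds": ..., "milliseconds": ...}, "endTime": {...}}
//	}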
// sendHeliconeEvent sends a single event to Helicone manual logger endpoint
func (p *HeliconePlugin) sendHeliconeEvent(ctx context.Context, payload map[string]interface{}) error {
data, err := json.Marshal(payload)
if err != nil {
return err
}
req, err := http.NewRequestWithContext(ctx, "POST", p.endpoint, bytes.NewReader(data))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.apiKey)
resp, err := p.client.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
// 500s from Helicone and 400s (payload/schema issues) are treated as permanent
// and must not be retried.
if resp.StatusCode == http.StatusInternalServerError || resp.StatusCode == http.StatusBadRequest {
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(resp.Body)
log.Printf("Helicone API error %d: %s", resp.StatusCode, buf.String())
return &dispatcher.PermanentBackendError{Msg: fmt.Sprintf("helicone API returned status %d: %s", resp.StatusCode, buf.String())}
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(resp.Body)
log.Printf("Helicone API error %d: %s", resp.StatusCode, buf.String())
return fmt.Errorf("helicone API returned status %d", resp.StatusCode)
}
return nil
}
// Close cleans up the plugin resources
func (p *HeliconePlugin) Close() error {
// Nothing to clean up for HTTP client
return nil
}
// mustMarshalJSON marshals v to JSON or returns an error string
func mustMarshalJSON(v interface{}) string {
b, err := json.Marshal(v)
if err != nil {
return "<marshal error>"
}
return string(b)
}
package plugins
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"github.com/sofatutor/llm-proxy/internal/dispatcher"
)
// LunaryPlugin implements Lunary.ai backend integration
type LunaryPlugin struct {
apiKey string
endpoint string
client *http.Client
}
// NewLunaryPlugin creates a new Lunary plugin
func NewLunaryPlugin() *LunaryPlugin {
return &LunaryPlugin{
client: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// Init initializes the Lunary plugin with configuration
func (p *LunaryPlugin) Init(cfg map[string]string) error {
apiKey, ok := cfg["api-key"]
if !ok || apiKey == "" {
return fmt.Errorf("lunary plugin requires 'api-key' configuration")
}
endpoint, ok := cfg["endpoint"]
if !ok || endpoint == "" {
endpoint = "https://api.lunary.ai/v1/runs/ingest"
}
p.apiKey = apiKey
p.endpoint = endpoint
return nil
}
// SendEvents sends events to Lunary.ai
func (p *LunaryPlugin) SendEvents(ctx context.Context, events []dispatcher.EventPayload) error {
if len(events) == 0 {
return nil
}
// Lunary expects an array of events
data, err := json.Marshal(events)
if err != nil {
return fmt.Errorf("failed to marshal events: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", p.endpoint, bytes.NewReader(data))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.apiKey)
resp, err := p.client.Do(req)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
log.Printf("[lunary] failed to close response body: %v", err)
}
}()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("lunary API returned status %d", resp.StatusCode)
}
return nil
}
// Close cleans up the plugin resources
func (p *LunaryPlugin) Close() error {
// Nothing to clean up for HTTP client
return nil
}
package plugins
import (
"fmt"
"github.com/sofatutor/llm-proxy/internal/dispatcher"
)
// PluginFactory is a function that creates a new plugin instance
type PluginFactory func() dispatcher.BackendPlugin
// Registry holds all available plugin factories
var Registry = make(map[string]PluginFactory)
// init registers all built-in plugins
func init() {
Registry["file"] = func() dispatcher.BackendPlugin {
return NewFilePlugin()
}
Registry["lunary"] = func() dispatcher.BackendPlugin {
return NewLunaryPlugin()
}
Registry["helicone"] = func() dispatcher.BackendPlugin {
return NewHeliconePlugin()
}
}
// NewPlugin creates a new plugin instance by name
func NewPlugin(name string) (dispatcher.BackendPlugin, error) {
factory, exists := Registry[name]
if !exists {
return nil, fmt.Errorf("unknown plugin: %s", name)
}
return factory(), nil
}
// ListPlugins returns a list of available plugin names
func ListPlugins() []string {
var names []string
for name := range Registry {
names = append(names, name)
}
return names
}
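// Example usage (illustrative sketch; the "lunary" name and "api-key" option
// follow the built-in registrations and plugin Init contracts above):
//
//	plugin, err := NewPlugin("lunary")
//	if err != nil {
//	    return err
//	}
//	if err := plugin.Init(map[string]string{"api-key": "<key>"}); err != nil {
//	    return err
//	}
//	defer func() { _ = plugin.Close() }()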
package dispatcher
import (
"context"
"fmt"
"os"
"os/signal"
"strconv"
"sync"
"syscall"
"time"
"github.com/sofatutor/llm-proxy/internal/eventbus"
"go.uber.org/zap"
)
// Config holds configuration for the dispatcher service
type Config struct {
BufferSize int
BatchSize int
FlushInterval time.Duration
RetryAttempts int
RetryBackoff time.Duration
Plugin BackendPlugin
EventTransformer EventTransformer
PluginName string
Verbose bool // If true, include response_headers and extra debug info
}
// Service represents the event dispatcher service
type Service struct {
config Config
eventBus eventbus.EventBus
logger *zap.Logger
stopCh chan struct{}
wg sync.WaitGroup
stopOnce sync.Once
// startedCh is closed after the event processing goroutine has been added
// to the WaitGroup. This avoids a data race between Wait() and Add(1).
startedCh chan struct{}
// metrics
mu sync.Mutex
eventsProcessed int64
eventsDropped int64
eventsSent int64
}
// NewService creates a new dispatcher service
func NewService(cfg Config, logger *zap.Logger) (*Service, error) {
if cfg.Plugin == nil {
return nil, fmt.Errorf("backend plugin is required")
}
if cfg.EventTransformer == nil {
cfg.EventTransformer = NewDefaultEventTransformer(cfg.Verbose)
}
if cfg.BufferSize <= 0 {
cfg.BufferSize = 1000
}
if cfg.BatchSize <= 0 {
cfg.BatchSize = 100
}
if cfg.FlushInterval <= 0 {
cfg.FlushInterval = 5 * time.Second
}
if cfg.RetryAttempts <= 0 {
cfg.RetryAttempts = 3
}
if cfg.RetryBackoff <= 0 {
cfg.RetryBackoff = time.Second
}
if logger == nil {
logger = zap.NewNop()
}
// Create event bus for the dispatcher
bus := eventbus.NewInMemoryEventBus(cfg.BufferSize)
return &Service{
config: cfg,
eventBus: bus,
logger: logger,
stopCh: make(chan struct{}),
startedCh: make(chan struct{}),
}, nil
}
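// Minimal construction sketch (illustrative; myPlugin stands in for any
// BackendPlugin implementation):
//
//	svc, err := NewService(Config{Plugin: myPlugin}, logger)
//	// Unset fields fall back to the defaults applied above:
//	// BufferSize 1000, BatchSize 100, FlushInterval 5s,
//	// RetryAttempts 3, RetryBackoff 1s.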
// NewServiceWithBus creates a new dispatcher service with a provided event bus.
func NewServiceWithBus(cfg Config, logger *zap.Logger, bus eventbus.EventBus) (*Service, error) {
if cfg.Plugin == nil {
return nil, fmt.Errorf("backend plugin is required")
}
if cfg.EventTransformer == nil {
cfg.EventTransformer = NewDefaultEventTransformer(cfg.Verbose)
}
if cfg.BufferSize <= 0 {
cfg.BufferSize = 1000
}
if cfg.BatchSize <= 0 {
cfg.BatchSize = 100
}
if cfg.FlushInterval <= 0 {
cfg.FlushInterval = 5 * time.Second
}
if cfg.RetryAttempts <= 0 {
cfg.RetryAttempts = 3
}
if cfg.RetryBackoff <= 0 {
cfg.RetryBackoff = time.Second
}
if logger == nil {
logger = zap.NewNop()
}
if bus == nil {
return nil, fmt.Errorf("event bus must not be nil")
}
return &Service{
config: cfg,
eventBus: bus,
logger: logger,
stopCh: make(chan struct{}),
startedCh: make(chan struct{}),
}, nil
}
// Run starts the dispatcher service and blocks until stopped
func (s *Service) Run(ctx context.Context, detach bool) error {
s.logger.Info("Starting event dispatcher service")
if detach {
return s.runDetached(ctx)
}
return s.runForeground(ctx)
}
// runDetached runs the service in background mode
func (s *Service) runDetached(ctx context.Context) error {
s.logger.Info("Running in detached mode")
// For detached mode, we still run in foreground but could be enhanced
// to fork the process or use systemd/supervisor in production
return s.runForeground(ctx)
}
// runForeground runs the service in foreground mode
func (s *Service) runForeground(ctx context.Context) error {
// Handle graceful shutdown
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
// Start event processing goroutine
s.wg.Add(1)
// Signal that Add(1) has completed before any potential Stop() Wait
close(s.startedCh)
go s.processEvents(ctx)
// Wait for shutdown signal
select {
case <-sigs:
s.logger.Info("Received shutdown signal")
case <-ctx.Done():
s.logger.Info("Context cancelled")
}
return s.Stop()
}
// Stop gracefully stops the dispatcher service
func (s *Service) Stop() error {
s.stopOnce.Do(func() {
s.logger.Info("Stopping event dispatcher service")
close(s.stopCh)
// Avoid racing Wait() with a concurrent Add(1) during startup.
select {
case <-s.startedCh:
s.wg.Wait()
default:
// Not started; nothing to wait for
}
if s.eventBus != nil {
s.eventBus.Stop()
}
if s.config.Plugin != nil {
if err := s.config.Plugin.Close(); err != nil {
s.logger.Error("Error closing plugin", zap.Error(err))
}
}
s.logger.Info("Event dispatcher service stopped")
})
return nil
}
// EventBus returns the event bus for this service (for connecting with other components)
func (s *Service) EventBus() eventbus.EventBus {
return s.eventBus
}
// processEvents handles the main event processing loop
func (s *Service) processEvents(ctx context.Context) {
defer s.wg.Done()
sub := s.eventBus.Subscribe()
batch := make([]EventPayload, 0, s.config.BatchSize)
ticker := time.NewTicker(s.config.FlushInterval)
defer ticker.Stop()
// Detect whether Subscribe returned an already-closed channel (the log-based
// Redis bus does this). A fresh in-memory subscription has no buffered events
// yet, so in practice the non-blocking receive below only observes closure.
closed := false
select {
case _, ok := <-sub:
if !ok {
closed = true
}
default:
}
if closed {
// Log-based consumption: poll for new events using ReadEvents
// Persist last-seen LogID in Redis for this dispatcher
client, ok := s.eventBus.(*eventbus.RedisEventBus)
if !ok {
s.logger.Error("eventBus is not RedisEventBus; cannot persist offset")
return
}
redisClient := client.Client()
// Use the dispatcher type (plugin name) for offset key: one dispatcher per type
serviceType := s.config.PluginName
if serviceType == "" {
serviceType = "default"
}
dispatcherKey := "llm-proxy-dispatcher:" + serviceType + ":last_id"
// Read last-seen LogID from Redis
var lastSeenID int64 = 0
if val, err := redisClient.Get(ctx, dispatcherKey); err == nil && val != "" {
if id, err := strconv.ParseInt(val, 10, 64); err == nil {
lastSeenID = id
}
}
for {
select {
case <-ticker.C:
ctxPoll, cancel := context.WithTimeout(ctx, s.config.FlushInterval)
events, err := client.ReadEvents(ctxPoll, 0, -1)
// Cancel immediately: a defer inside this loop would accumulate until
// processEvents returns, leaking a context per tick.
cancel()
if err != nil {
s.logger.Error("Failed to read events from log", zap.Error(err))
continue
}
// Filter events with LogID > lastSeenID
newEvents := make([]eventbus.Event, 0)
for _, evt := range events {
if evt.LogID > lastSeenID {
newEvents = append(newEvents, evt)
}
}
// Reverse newEvents to process from oldest to newest
for i, j := 0, len(newEvents)-1; i < j; i, j = i+1, j-1 {
newEvents[i], newEvents[j] = newEvents[j], newEvents[i]
}
if len(newEvents) > 0 {
if newEvents[0].LogID > lastSeenID+1 {
s.logger.Warn("Missed events due to TTL or trimming", zap.Int64("last_seen_id", lastSeenID), zap.Int64("first_log_id", newEvents[0].LogID))
}
// Prepare batch and track maxLogID in this batch
maxLogID := lastSeenID
batch = batch[:0]
for _, evt := range newEvents {
payload, err := s.config.EventTransformer.Transform(evt)
if err != nil {
s.mu.Lock()
s.eventsDropped++
s.mu.Unlock()
s.logger.Error("Failed to transform event", zap.Error(err))
continue
}
if payload == nil {
continue
}
batch = append(batch, *payload)
s.mu.Lock()
s.eventsProcessed++
s.mu.Unlock()
if evt.LogID > maxLogID {
maxLogID = evt.LogID
}
}
if len(batch) > 0 {
// Only update/persist lastSeenID if sendBatch succeeds
err := s.sendBatchWithResult(ctx, batch)
if err == nil {
lastSeenID = maxLogID
_ = redisClient.Set(ctx, dispatcherKey, strconv.FormatInt(lastSeenID, 10))
}
batch = batch[:0]
}
}
case <-s.stopCh:
return
case <-ctx.Done():
return
}
}
}
// Channel-based (in-memory) event bus
for {
select {
case evt, ok := <-sub:
if !ok {
// Channel closed, flush remaining batch and exit
if len(batch) > 0 {
s.sendBatch(ctx, batch)
}
return
}
// Transform the event
payload, err := s.config.EventTransformer.Transform(evt)
if err != nil {
s.mu.Lock()
s.eventsDropped++
s.mu.Unlock()
s.logger.Error("Failed to transform event", zap.Error(err))
continue
}
if payload == nil {
// Event was filtered out (e.g., OPTIONS request)
continue
}
batch = append(batch, *payload)
s.mu.Lock()
s.eventsProcessed++
s.mu.Unlock()
// Send batch if it's full
if len(batch) >= s.config.BatchSize {
s.sendBatch(ctx, batch)
batch = batch[:0] // Reset slice
}
case <-ticker.C:
// Flush batch on timer
if len(batch) > 0 {
s.sendBatch(ctx, batch)
batch = batch[:0] // Reset slice
}
case <-s.stopCh:
// Flush remaining batch and exit
if len(batch) > 0 {
s.sendBatch(ctx, batch)
}
return
}
}
}
// sendBatch sends a batch of events to the configured backend with retry logic
func (s *Service) sendBatch(ctx context.Context, batch []EventPayload) {
for attempt := 0; attempt <= s.config.RetryAttempts; attempt++ {
err := s.config.Plugin.SendEvents(ctx, batch)
if err == nil {
s.mu.Lock()
s.eventsSent += int64(len(batch))
s.mu.Unlock()
s.logger.Debug("Successfully sent batch",
zap.Int("batch_size", len(batch)),
zap.Int("attempt", attempt+1))
return
}
if attempt < s.config.RetryAttempts {
// Linear backoff: 1x, 2x, 3x ... RetryBackoff per successive attempt
backoff := time.Duration(attempt+1) * s.config.RetryBackoff
s.logger.Warn("Failed to send batch, retrying",
zap.Error(err),
zap.Int("attempt", attempt+1),
zap.Duration("backoff", backoff))
select {
case <-time.After(backoff):
case <-ctx.Done():
return
case <-s.stopCh:
return
}
} else {
s.logger.Error("Failed to send batch after all retries",
zap.Error(err),
zap.Int("batch_size", len(batch)))
s.mu.Lock()
s.eventsDropped += int64(len(batch))
s.mu.Unlock()
}
}
}
// sendBatchWithResult sends a batch of events to the configured backend with retry logic and returns the result
func (s *Service) sendBatchWithResult(ctx context.Context, batch []EventPayload) error {
for attempt := 0; attempt <= s.config.RetryAttempts; attempt++ {
err := s.config.Plugin.SendEvents(ctx, batch)
if err == nil {
s.mu.Lock()
s.eventsSent += int64(len(batch))
s.mu.Unlock()
s.logger.Debug("Successfully sent batch",
zap.Int("batch_size", len(batch)),
zap.Int("attempt", attempt+1))
return nil
}
// If PermanentBackendError, treat as delivered and do not retry
if _, ok := err.(*PermanentBackendError); ok {
s.logger.Warn("Permanent backend error, skipping batch", zap.Error(err), zap.Int("batch_size", len(batch)))
s.mu.Lock()
s.eventsDropped += int64(len(batch))
s.mu.Unlock()
return nil // treat as delivered
}
if attempt < s.config.RetryAttempts {
// Linear backoff: 1x, 2x, 3x ... RetryBackoff per successive attempt
backoff := time.Duration(attempt+1) * s.config.RetryBackoff
s.logger.Warn("Failed to send batch, retrying",
zap.Error(err),
zap.Int("attempt", attempt+1),
zap.Duration("backoff", backoff))
select {
case <-time.After(backoff):
case <-ctx.Done():
return ctx.Err()
case <-s.stopCh:
return fmt.Errorf("stopped")
}
} else {
s.logger.Error("Failed to send batch after all retries",
zap.Error(err),
zap.Int("batch_size", len(batch)))
s.mu.Lock()
s.eventsDropped += int64(len(batch))
s.mu.Unlock()
return err
}
}
return fmt.Errorf("unreachable")
}
// Stats returns service statistics
func (s *Service) Stats() (processed, dropped, sent int64) {
s.mu.Lock()
defer s.mu.Unlock()
return s.eventsProcessed, s.eventsDropped, s.eventsSent
}
package dispatcher
import (
"bytes"
"compress/gzip"
"encoding/json"
"io"
"log"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/andybalholm/brotli"
"github.com/google/uuid"
"github.com/sofatutor/llm-proxy/internal/eventbus"
"github.com/sofatutor/llm-proxy/internal/eventtransformer"
)
// DefaultEventTransformer provides a basic transformation from eventbus.Event to EventPayload
// Verbose: if true, includes response_headers in metadata
// Use NewDefaultEventTransformer(verbose) to construct
type DefaultEventTransformer struct {
Verbose bool
}
// NewDefaultEventTransformer creates a transformer with the given verbosity
func NewDefaultEventTransformer(verbose bool) *DefaultEventTransformer {
return &DefaultEventTransformer{Verbose: verbose}
}
// cleanJSONBinary recursively replaces binary fields in a JSON object with a placeholder
func cleanJSONBinary(v interface{}) interface{} {
switch val := v.(type) {
case map[string]interface{}:
for k, v2 := range val {
val[k] = cleanJSONBinary(v2)
}
return val
case []interface{}:
for i, v2 := range val {
val[i] = cleanJSONBinary(v2)
}
return val
case string:
if !utf8.ValidString(val) {
return "<binary omitted>"
}
return val
case []byte:
if !utf8.Valid(val) {
return "<binary omitted>"
}
return string(val)
default:
return val
}
}
// safeRawMessageOrBase64 tries to decode data as JSON, decompressing with gzip or
// brotli if needed, then falls back to treating it as a UTF-8 string; multipart and
// binary bodies are replaced with quoted placeholders. The second return value is
// reserved for a base64 fallback and is currently always empty.
// If Content-Type is JSON, always return cleaned JSON (with binary fields replaced)
func safeRawMessageOrBase64(data []byte, headers map[string][]string) (json.RawMessage, string) {
if len(data) == 0 {
return nil, ""
}
var js json.RawMessage
// Check for Content-Encoding
encoding := ""
contentType := ""
if headers != nil {
if v, ok := headers["Content-Encoding"]; ok && len(v) > 0 {
encoding = v[0]
}
if v, ok := headers["Content-Type"]; ok && len(v) > 0 {
contentType = v[0]
}
}
// If Content-Type is multipart, return a placeholder
if strings.Contains(contentType, "multipart") {
return []byte(strconv.Quote("<multipart response omitted>")), ""
}
decompressed := data
var decompressErr error
switch encoding {
case "gzip":
zr, err := gzip.NewReader(bytes.NewReader(data))
if err == nil {
decompressed, decompressErr = io.ReadAll(zr)
_ = zr.Close()
} else {
decompressErr = err
}
case "br":
br := brotli.NewReader(bytes.NewReader(data))
var err error
decompressed, err = io.ReadAll(br)
if err != nil {
decompressErr = err
}
}
// If Content-Type is JSON, always return cleaned JSON
if strings.Contains(contentType, "json") {
var obj interface{}
if json.Unmarshal(decompressed, &obj) == nil {
cleaned := cleanJSONBinary(obj)
if jsBytes, err := json.Marshal(cleaned); err == nil {
return jsBytes, ""
}
}
}
if decompressErr == nil && json.Unmarshal(decompressed, &js) == nil {
return js, ""
} else if decompressErr != nil {
log.Printf("[transformer] Decompression failed: %v", decompressErr)
} else if err := json.Unmarshal(decompressed, &js); err != nil {
if strings.Contains(contentType, "json") {
log.Printf("[transformer] JSON unmarshal after decompress failed: %v First 64 bytes: %x", err, decompressed[:min(64, len(decompressed))])
}
}
// If decompression failed, fall back to a direct JSON unmarshal of the raw data
if decompressErr != nil {
if json.Unmarshal(data, &js) == nil {
return js, ""
}
log.Printf("[transformer] Raw JSON unmarshal also failed (decompression error: %v). First 64 bytes: %x", decompressErr, data[:min(64, len(data))])
}
// If valid UTF-8, try to parse as JSON string or OpenAI event stream
if utf8.Valid(decompressed) {
str := string(decompressed)
trim := strings.TrimSpace(str)
if (strings.HasPrefix(trim, "{") && strings.HasSuffix(trim, "}")) ||
(strings.HasPrefix(trim, "[") && strings.HasSuffix(trim, "]")) {
// Looks like JSON object/array in a string
if json.Unmarshal([]byte(trim), &js) == nil {
return js, ""
}
}
// OpenAI streaming or event lines
if eventtransformer.IsOpenAIStreaming(str) {
if merged, err := eventtransformer.MergeOpenAIStreamingChunks(str); err == nil {
if js, err := json.Marshal(merged); err == nil {
return js, ""
}
}
}
if strings.Contains(str, "event: ") && strings.Contains(str, "data: ") {
if merged, err := eventtransformer.MergeThreadStreamingChunks(str); err == nil {
if js, err := json.Marshal(merged); err == nil {
return js, ""
}
}
}
// Fallback: log as JSON string
quoted := []byte(strconv.Quote(str))
return quoted, ""
}
// For binary data, return a placeholder string instead of base64
return []byte(strconv.Quote("<binary response omitted>")), ""
}
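// Behavior sketch (illustrative; pngBytes stands in for arbitrary non-UTF-8 data):
//
//	js, _ := safeRawMessageOrBase64([]byte(`{"ok":true}`), nil) // js == `{"ok":true}`
//	js, _ = safeRawMessageOrBase64(pngBytes, nil)               // quoted "<binary response omitted>"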
func min(a, b int) int {
if a < b {
return a
}
return b
}
// Transform converts an eventbus.Event to an EventPayload
func (t *DefaultEventTransformer) Transform(evt eventbus.Event) (*EventPayload, error) {
// Skip non-POST requests (like OPTIONS, GET)
if evt.Method != "POST" {
return nil, nil
}
// Generate a unique run ID for this event
runID := uuid.New().String()
// Basic transformation
payload := &EventPayload{
Type: "llm",
Event: "start", // For now, all events are considered "start" events
RunID: runID,
Timestamp: time.Now(),
LogID: evt.LogID,
Metadata: map[string]any{
"method": evt.Method,
"path": evt.Path,
"status": evt.Status,
"duration_ms": evt.Duration.Milliseconds(),
"request_id": evt.RequestID,
},
}
// Add request body as input (JSON or base64)
if len(evt.RequestBody) > 0 {
if js, b64 := safeRawMessageOrBase64(evt.RequestBody, nil); js != nil {
payload.Input = js
} else {
payload.InputBase64 = b64
}
}
// --- OpenAI-specific output transformation ---
isOpenAI := strings.HasPrefix(evt.Path, "/v1/completions") ||
strings.HasPrefix(evt.Path, "/v1/chat/completions") ||
strings.HasPrefix(evt.Path, "/v1/threads/")
if isOpenAI && len(evt.ResponseBody) > 0 {
// Only use OpenAI transformer if response is valid JSON
if json.Valid(evt.ResponseBody) {
openaiTransformer := &eventtransformer.OpenAITransformer{}
parsed, err := openaiTransformer.TransformEvent(map[string]any{
"response_body": string(evt.ResponseBody),
"path": evt.Path,
})
if err == nil && parsed != nil {
if js, err := json.Marshal(parsed); err == nil {
payload.Output = js
// Optionally extract token usage if present; checked type assertions
// ensure a missing or non-numeric field cannot cause a panic.
if usage, ok := parsed["usage"].(map[string]any); ok {
prompt, pOK := usage["prompt_tokens"].(float64)
completion, cOK := usage["completion_tokens"].(float64)
if pOK && cOK {
payload.TokensUsage = &TokensUsage{
Prompt:     int(prompt),
Completion: int(completion),
}
}
}
return payload, nil
}
}
// If parsing fails, fall through to generic logic
}
}
// Add response body as output (JSON or base64)
if len(evt.ResponseBody) > 0 {
if js, b64 := safeRawMessageOrBase64(evt.ResponseBody, evt.ResponseHeaders); js != nil {
payload.Output = js
} else {
payload.OutputBase64 = b64
}
}
// Add response headers to metadata only if Verbose is true
if t.Verbose && evt.ResponseHeaders != nil {
headers := make(map[string]any)
for k, v := range evt.ResponseHeaders {
if len(v) == 1 {
headers[k] = v[0]
} else {
headers[k] = v
}
}
payload.Metadata["response_headers"] = headers
}
return payload, nil
}
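// Filtering sketch (illustrative): non-POST events are skipped by returning
// (nil, nil), which the dispatcher treats as filtered out rather than an error.
//
//	t := NewDefaultEventTransformer(false)
//	payload, err := t.Transform(eventbus.Event{Method: "OPTIONS"})
//	// payload == nil, err == nil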
// Package encryption provides utilities for encrypting and decrypting sensitive data.
// It uses AES-256-GCM for symmetric encryption of data at rest.
package encryption
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
)
const (
// KeySize is the required size for AES-256 encryption keys (32 bytes).
KeySize = 32
// NonceSize is the size of the GCM nonce (12 bytes).
NonceSize = 12
// EncryptedPrefix is added to encrypted values to identify them.
EncryptedPrefix = "enc:v1:"
)
var (
// ErrInvalidKeySize is returned when the encryption key has an invalid size.
ErrInvalidKeySize = errors.New("encryption key must be exactly 32 bytes")
// ErrDecryptionFailed is returned when decryption fails.
ErrDecryptionFailed = errors.New("decryption failed")
// ErrNoEncryptionKey is returned when no encryption key is configured.
ErrNoEncryptionKey = errors.New("no encryption key configured")
// ErrInvalidCiphertext is returned when the ciphertext is invalid.
ErrInvalidCiphertext = errors.New("invalid ciphertext format")
)
// Encryptor provides encryption and decryption operations.
// It is safe for concurrent use - cipher.AEAD implementations are thread-safe.
type Encryptor struct {
gcm cipher.AEAD
}
// NewEncryptor creates a new Encryptor with the given 32-byte key.
// The key must be exactly 32 bytes for AES-256 encryption.
func NewEncryptor(key []byte) (*Encryptor, error) {
if len(key) != KeySize {
return nil, ErrInvalidKeySize
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
return &Encryptor{gcm: gcm}, nil
}
// NewEncryptorFromBase64Key creates a new Encryptor from a base64-encoded key.
func NewEncryptorFromBase64Key(base64Key string) (*Encryptor, error) {
if base64Key == "" {
return nil, ErrNoEncryptionKey
}
key, err := base64.StdEncoding.DecodeString(base64Key)
if err != nil {
return nil, fmt.Errorf("failed to decode base64 key: %w", err)
}
return NewEncryptor(key)
}
// Encrypt encrypts plaintext and returns a base64-encoded ciphertext with prefix.
func (e *Encryptor) Encrypt(plaintext string) (string, error) {
if plaintext == "" {
return "", nil // Empty strings are not encrypted
}
// Generate a random nonce
nonce := make([]byte, NonceSize)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("failed to generate nonce: %w", err)
}
// Encrypt the plaintext
ciphertext := e.gcm.Seal(nonce, nonce, []byte(plaintext), nil)
// Encode to base64 and add prefix
encoded := base64.StdEncoding.EncodeToString(ciphertext)
return EncryptedPrefix + encoded, nil
}
// Decrypt decrypts a base64-encoded ciphertext and returns the plaintext.
// If the value is not encrypted (no prefix), it returns the value as-is.
func (e *Encryptor) Decrypt(ciphertext string) (string, error) {
if ciphertext == "" {
return "", nil // Empty strings are returned as-is
}
// Check if the value is encrypted
if !IsEncrypted(ciphertext) {
return ciphertext, nil // Return unencrypted values as-is (backward compatibility)
}
// Remove the prefix
encoded := ciphertext[len(EncryptedPrefix):]
// Decode from base64
data, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return "", fmt.Errorf("failed to decode ciphertext: %w", err)
}
// Validate minimum length (nonce + tag; plaintext can be empty)
if len(data) < NonceSize+e.gcm.Overhead() {
return "", ErrInvalidCiphertext
}
// Extract nonce and ciphertext
nonce := data[:NonceSize]
encryptedData := data[NonceSize:]
// Decrypt
plaintext, err := e.gcm.Open(nil, nonce, encryptedData, nil)
if err != nil {
return "", ErrDecryptionFailed
}
return string(plaintext), nil
}
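// Round-trip sketch (illustrative):
//
//	key, _ := GenerateKey()
//	enc, _ := NewEncryptor(key)
//	ct, _ := enc.Encrypt("sk-secret")     // "enc:v1:" + base64(nonce || ciphertext)
//	pt, _ := enc.Decrypt(ct)              // "sk-secret"
//	raw, _ := enc.Decrypt("legacy-value") // returned as-is: no "enc:v1:" prefix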
// IsEncrypted checks if a value has the encryption prefix.
func IsEncrypted(value string) bool {
return len(value) > len(EncryptedPrefix) && value[:len(EncryptedPrefix)] == EncryptedPrefix
}
// GenerateKey generates a new random 32-byte encryption key.
func GenerateKey() ([]byte, error) {
key := make([]byte, KeySize)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
}
return key, nil
}
// GenerateKeyBase64 generates a new random encryption key and returns it as base64.
func GenerateKeyBase64() (string, error) {
key, err := GenerateKey()
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(key), nil
}
// NullEncryptor is a no-op encryptor for when encryption is disabled.
type NullEncryptor struct{}
// NewNullEncryptor creates a new NullEncryptor.
func NewNullEncryptor() *NullEncryptor {
return &NullEncryptor{}
}
// Encrypt returns the plaintext as-is (no encryption).
func (e *NullEncryptor) Encrypt(plaintext string) (string, error) {
return plaintext, nil
}
// Decrypt returns the ciphertext as-is (no decryption).
func (e *NullEncryptor) Decrypt(ciphertext string) (string, error) {
return ciphertext, nil
}
// FieldEncryptor is an interface for encrypting and decrypting field values.
type FieldEncryptor interface {
Encrypt(plaintext string) (string, error)
Decrypt(ciphertext string) (string, error)
}
// Compile-time interface checks
var (
_ FieldEncryptor = (*Encryptor)(nil)
_ FieldEncryptor = (*NullEncryptor)(nil)
)
// Package encryption provides utilities for hashing sensitive data.
// It uses bcrypt for secure password-like hashing of tokens.
package encryption
import (
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"errors"
"fmt"
"golang.org/x/crypto/bcrypt"
)
const (
// HashPrefix is added to hashed values to identify them.
HashPrefix = "hash:v1:"
// DefaultBcryptCost is the default cost parameter for bcrypt.
// A cost of 10 is a good balance between security and performance.
DefaultBcryptCost = 10
)
var (
// ErrHashMismatch is returned when a hash comparison fails.
ErrHashMismatch = errors.New("hash does not match")
// ErrInvalidHash is returned when the hash format is invalid.
ErrInvalidHash = errors.New("invalid hash format")
)
// TokenHasher provides secure hashing for authentication tokens.
// It uses SHA-256 for creating lookup keys and bcrypt for secure storage.
type TokenHasher struct {
bcryptCost int
}
// NewTokenHasher creates a new TokenHasher with the default bcrypt cost.
func NewTokenHasher() *TokenHasher {
return &TokenHasher{bcryptCost: DefaultBcryptCost}
}
// NewTokenHasherWithCost creates a new TokenHasher with a custom bcrypt cost.
func NewTokenHasherWithCost(cost int) (*TokenHasher, error) {
if cost < bcrypt.MinCost || cost > bcrypt.MaxCost {
return nil, fmt.Errorf("bcrypt cost must be between %d and %d", bcrypt.MinCost, bcrypt.MaxCost)
}
return &TokenHasher{bcryptCost: cost}, nil
}
// HashToken creates a bcrypt hash of a token for secure storage.
// Returns a hash prefixed with HashPrefix for identification.
// For tokens longer than 72 bytes, a SHA-256 pre-hash is used since
// bcrypt has a 72-byte input limit.
func (h *TokenHasher) HashToken(token string) (string, error) {
if token == "" {
return "", errors.New("token cannot be empty")
}
// Pre-hash if token is too long for bcrypt (72 byte limit)
input := []byte(token)
if len(input) > 72 {
sum := sha256.Sum256(input)
input = sum[:]
}
hash, err := bcrypt.GenerateFromPassword(input, h.bcryptCost)
if err != nil {
return "", fmt.Errorf("failed to hash token: %w", err)
}
return HashPrefix + string(hash), nil
}
// VerifyToken compares a plaintext token against a stored hash.
// It returns nil if the token matches, or ErrHashMismatch if it doesn't.
func (h *TokenHasher) VerifyToken(token, hashedToken string) error {
if token == "" || hashedToken == "" {
return ErrHashMismatch
}
// Check if it's a hashed value
if !IsHashed(hashedToken) {
// For backward compatibility, do a constant-time comparison
// if the stored value is not hashed
if subtle.ConstantTimeCompare([]byte(token), []byte(hashedToken)) == 1 {
return nil
}
return ErrHashMismatch
}
// Remove the prefix
bcryptHash := hashedToken[len(HashPrefix):]
// Pre-hash if token is too long for bcrypt (72 byte limit)
input := []byte(token)
if len(input) > 72 {
sum := sha256.Sum256(input)
input = sum[:]
}
// Verify with bcrypt
err := bcrypt.CompareHashAndPassword([]byte(bcryptHash), input)
if err != nil {
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
return ErrHashMismatch
}
return fmt.Errorf("failed to verify token: %w", err)
}
return nil
}
// CreateLookupKey creates a deterministic hash for token lookup.
// This is used as an index key in the database for finding tokens.
// Uses SHA-256 which is fast and collision-resistant.
func (h *TokenHasher) CreateLookupKey(token string) string {
if token == "" {
return ""
}
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
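// Usage sketch (illustrative):
//
//	h := NewTokenHasher()
//	stored, _ := h.HashToken("tok_abc")     // "hash:v1:" + bcrypt hash, for storage
//	err := h.VerifyToken("tok_abc", stored) // nil on match
//	key := h.CreateLookupKey("tok_abc")     // deterministic 64-char hex SHA-256 index key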
// IsHashed checks if a value has the hash prefix.
func IsHashed(value string) bool {
return len(value) > len(HashPrefix) && value[:len(HashPrefix)] == HashPrefix
}
// NullTokenHasher is a no-op hasher for when hashing is disabled.
type NullTokenHasher struct{}
// NewNullTokenHasher creates a new NullTokenHasher.
func NewNullTokenHasher() *NullTokenHasher {
return &NullTokenHasher{}
}
// HashToken returns the token as-is (no hashing).
func (h *NullTokenHasher) HashToken(token string) (string, error) {
return token, nil
}
// VerifyToken performs a constant-time comparison of the tokens.
func (h *NullTokenHasher) VerifyToken(token, storedToken string) error {
if subtle.ConstantTimeCompare([]byte(token), []byte(storedToken)) == 1 {
return nil
}
return ErrHashMismatch
}
// CreateLookupKey returns the token as-is (it's already the lookup key).
func (h *NullTokenHasher) CreateLookupKey(token string) string {
return token
}
// TokenHasherInterface defines the interface for token hashing operations.
type TokenHasherInterface interface {
HashToken(token string) (string, error)
VerifyToken(token, hashedToken string) error
CreateLookupKey(token string) string
}
// Compile-time interface checks
var (
_ TokenHasherInterface = (*TokenHasher)(nil)
_ TokenHasherInterface = (*NullTokenHasher)(nil)
)
// Package encryption provides a secure database wrapper that encrypts/decrypts sensitive fields.
package encryption
import (
"context"
"fmt"
"github.com/sofatutor/llm-proxy/internal/proxy"
)
// SecureProjectStore wraps a ProjectStore and encrypts/decrypts API keys.
type SecureProjectStore struct {
store proxy.ProjectStore
encryptor FieldEncryptor
}
// NewSecureProjectStore creates a new SecureProjectStore.
// The encryptor is used to encrypt API keys before storing and decrypt after retrieval.
// If encryptor is nil, a NullEncryptor is used (no encryption).
func NewSecureProjectStore(store proxy.ProjectStore, encryptor FieldEncryptor) *SecureProjectStore {
if encryptor == nil {
encryptor = NewNullEncryptor()
}
return &SecureProjectStore{
store: store,
encryptor: encryptor,
}
}
// GetAPIKeyForProject retrieves and decrypts the API key for a project.
func (s *SecureProjectStore) GetAPIKeyForProject(ctx context.Context, projectID string) (string, error) {
encryptedKey, err := s.store.GetAPIKeyForProject(ctx, projectID)
if err != nil {
return "", err
}
// Decrypt the API key
decryptedKey, err := s.encryptor.Decrypt(encryptedKey)
if err != nil {
return "", fmt.Errorf("failed to decrypt API key: %w", err)
}
return decryptedKey, nil
}
// GetProjectActive returns whether a project is active.
func (s *SecureProjectStore) GetProjectActive(ctx context.Context, projectID string) (bool, error) {
return s.store.GetProjectActive(ctx, projectID)
}
// ListProjects retrieves all projects and decrypts their API keys.
func (s *SecureProjectStore) ListProjects(ctx context.Context) ([]proxy.Project, error) {
projects, err := s.store.ListProjects(ctx)
if err != nil {
return nil, err
}
// Decrypt API keys for each project
for i := range projects {
decryptedKey, err := s.encryptor.Decrypt(projects[i].OpenAIAPIKey)
if err != nil {
return nil, fmt.Errorf("failed to decrypt API key for project %s: %w", projects[i].ID, err)
}
projects[i].OpenAIAPIKey = decryptedKey
}
return projects, nil
}
// CreateProject encrypts the API key and creates the project.
func (s *SecureProjectStore) CreateProject(ctx context.Context, project proxy.Project) error {
// Encrypt the API key before storing
encryptedKey, err := s.encryptor.Encrypt(project.OpenAIAPIKey)
if err != nil {
return fmt.Errorf("failed to encrypt API key: %w", err)
}
project.OpenAIAPIKey = encryptedKey
return s.store.CreateProject(ctx, project)
}
// GetProjectByID retrieves a project and decrypts its API key.
func (s *SecureProjectStore) GetProjectByID(ctx context.Context, projectID string) (proxy.Project, error) {
project, err := s.store.GetProjectByID(ctx, projectID)
if err != nil {
return proxy.Project{}, err
}
// Decrypt the API key
decryptedKey, err := s.encryptor.Decrypt(project.OpenAIAPIKey)
if err != nil {
return proxy.Project{}, fmt.Errorf("failed to decrypt API key: %w", err)
}
project.OpenAIAPIKey = decryptedKey
return project, nil
}
// UpdateProject encrypts the API key and updates the project.
func (s *SecureProjectStore) UpdateProject(ctx context.Context, project proxy.Project) error {
// Only encrypt if the API key is not already encrypted
if !IsEncrypted(project.OpenAIAPIKey) {
encryptedKey, err := s.encryptor.Encrypt(project.OpenAIAPIKey)
if err != nil {
return fmt.Errorf("failed to encrypt API key: %w", err)
}
project.OpenAIAPIKey = encryptedKey
}
return s.store.UpdateProject(ctx, project)
}
// DeleteProject deletes a project.
func (s *SecureProjectStore) DeleteProject(ctx context.Context, projectID string) error {
return s.store.DeleteProject(ctx, projectID)
}
// Compile-time interface check
var _ proxy.ProjectStore = (*SecureProjectStore)(nil)
// Package encryption provides a secure token store wrapper that hashes tokens.
package encryption
import (
"context"
"github.com/sofatutor/llm-proxy/internal/token"
)
// SecureTokenStore wraps a TokenStore and hashes tokens before storage.
// This prevents tokens from being exposed if the database is compromised.
type SecureTokenStore struct {
store token.TokenStore
hasher TokenHasherInterface
}
// NewSecureTokenStore creates a new SecureTokenStore.
// If hasher is nil, a NullTokenHasher is used (no hashing).
func NewSecureTokenStore(store token.TokenStore, hasher TokenHasherInterface) *SecureTokenStore {
if hasher == nil {
hasher = NewNullTokenHasher()
}
return &SecureTokenStore{
store: store,
hasher: hasher,
}
}
// GetTokenByID retrieves a token by its ID (the original plaintext token).
// The token is hashed before lookup, and the returned TokenData will
// have the hashed token value (not the original).
func (s *SecureTokenStore) GetTokenByID(ctx context.Context, tokenID string) (token.TokenData, error) {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.GetTokenByID(ctx, hashedToken)
}
// IncrementTokenUsage increments the usage count for a token.
// The token is hashed before the operation.
func (s *SecureTokenStore) IncrementTokenUsage(ctx context.Context, tokenID string) error {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.IncrementTokenUsage(ctx, hashedToken)
}
// CreateToken creates a new token in the store.
// The token value is hashed before storage.
func (s *SecureTokenStore) CreateToken(ctx context.Context, td token.TokenData) error {
// Hash the token value
td.Token = s.hasher.CreateLookupKey(td.Token)
return s.store.CreateToken(ctx, td)
}
// UpdateToken updates an existing token.
// The token value is hashed before the operation.
func (s *SecureTokenStore) UpdateToken(ctx context.Context, td token.TokenData) error {
// Hash the token value if it's not already a SHA-256 hex string (64 hex chars)
// Validate both length and hex content to avoid skipping plaintext 64-char tokens
if len(td.Token) != 64 || !IsHexString(td.Token) {
td.Token = s.hasher.CreateLookupKey(td.Token)
}
return s.store.UpdateToken(ctx, td)
}
// IsHexString checks if a string contains only hexadecimal characters.
// Exported for use by migration tools and other packages.
func IsHexString(s string) bool {
for _, c := range s {
isDigit := c >= '0' && c <= '9'
isLowerHex := c >= 'a' && c <= 'f'
isUpperHex := c >= 'A' && c <= 'F'
if !isDigit && !isLowerHex && !isUpperHex {
return false
}
}
return true
}
// ListTokens retrieves all tokens from the store.
// Note: The returned tokens will have hashed token values.
func (s *SecureTokenStore) ListTokens(ctx context.Context) ([]token.TokenData, error) {
return s.store.ListTokens(ctx)
}
// GetTokensByProjectID retrieves all tokens for a project.
// Note: The returned tokens will have hashed token values.
func (s *SecureTokenStore) GetTokensByProjectID(ctx context.Context, projectID string) ([]token.TokenData, error) {
return s.store.GetTokensByProjectID(ctx, projectID)
}
// Compile-time interface check
var _ token.TokenStore = (*SecureTokenStore)(nil)
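// Wiring sketch (illustrative; baseStore stands in for any token.TokenStore):
//
//	secure := NewSecureTokenStore(baseStore, NewTokenHasher())
//	// Lookups hash the plaintext first, so only SHA-256 keys reach the store:
//	td, err := secure.GetTokenByID(ctx, plaintextToken)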
// SecureRevocationStore wraps a RevocationStore and hashes tokens before operations.
type SecureRevocationStore struct {
store token.RevocationStore
hasher TokenHasherInterface
}
// NewSecureRevocationStore creates a new SecureRevocationStore.
func NewSecureRevocationStore(store token.RevocationStore, hasher TokenHasherInterface) *SecureRevocationStore {
if hasher == nil {
hasher = NewNullTokenHasher()
}
return &SecureRevocationStore{
store: store,
hasher: hasher,
}
}
// RevokeToken revokes a token by its ID.
func (s *SecureRevocationStore) RevokeToken(ctx context.Context, tokenID string) error {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.RevokeToken(ctx, hashedToken)
}
// DeleteToken deletes a token by its ID.
func (s *SecureRevocationStore) DeleteToken(ctx context.Context, tokenID string) error {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.DeleteToken(ctx, hashedToken)
}
// RevokeBatchTokens revokes multiple tokens at once.
func (s *SecureRevocationStore) RevokeBatchTokens(ctx context.Context, tokenIDs []string) (int, error) {
hashedTokens := make([]string, len(tokenIDs))
for i, t := range tokenIDs {
hashedTokens[i] = s.hasher.CreateLookupKey(t)
}
return s.store.RevokeBatchTokens(ctx, hashedTokens)
}
// RevokeProjectTokens revokes all tokens for a project.
func (s *SecureRevocationStore) RevokeProjectTokens(ctx context.Context, projectID string) (int, error) {
return s.store.RevokeProjectTokens(ctx, projectID)
}
// RevokeExpiredTokens revokes all expired tokens.
func (s *SecureRevocationStore) RevokeExpiredTokens(ctx context.Context) (int, error) {
return s.store.RevokeExpiredTokens(ctx)
}
// Compile-time interface check
var _ token.RevocationStore = (*SecureRevocationStore)(nil)
// SecureRateLimitStore wraps a RateLimitStore and hashes tokens before operations.
type SecureRateLimitStore struct {
store token.RateLimitStore
hasher TokenHasherInterface
}
// NewSecureRateLimitStore creates a new SecureRateLimitStore.
func NewSecureRateLimitStore(store token.RateLimitStore, hasher TokenHasherInterface) *SecureRateLimitStore {
if hasher == nil {
hasher = NewNullTokenHasher()
}
return &SecureRateLimitStore{
store: store,
hasher: hasher,
}
}
// GetTokenByID retrieves a token by its ID.
func (s *SecureRateLimitStore) GetTokenByID(ctx context.Context, tokenID string) (token.TokenData, error) {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.GetTokenByID(ctx, hashedToken)
}
// IncrementTokenUsage increments the usage count for a token.
func (s *SecureRateLimitStore) IncrementTokenUsage(ctx context.Context, tokenID string) error {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.IncrementTokenUsage(ctx, hashedToken)
}
// ResetTokenUsage resets the usage count for a token to zero.
func (s *SecureRateLimitStore) ResetTokenUsage(ctx context.Context, tokenID string) error {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.ResetTokenUsage(ctx, hashedToken)
}
// UpdateTokenLimit updates the maximum allowed requests for a token.
func (s *SecureRateLimitStore) UpdateTokenLimit(ctx context.Context, tokenID string, maxRequests *int) error {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.UpdateTokenLimit(ctx, hashedToken, maxRequests)
}
// Compile-time interface check
var _ token.RateLimitStore = (*SecureRateLimitStore)(nil)
package eventbus
import (
"context"
"encoding/json"
"log"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
)
// Event represents an observability event emitted by the proxy.
type Event struct {
LogID int64 // Monotonic event log ID
RequestID string
Method string
Path string
Status int
Duration time.Duration
ResponseHeaders http.Header
ResponseBody []byte
RequestBody []byte
}
// EventBus is a simple interface for publishing events to subscribers.
type EventBus interface {
Publish(ctx context.Context, evt Event)
Subscribe() <-chan Event
Stop()
}
type busStats struct {
published atomic.Int64
dropped atomic.Int64
}
// InMemoryEventBus is an EventBus implementation backed by a buffered channel and
// fan-out broadcasting to multiple subscribers. Events are dispatched
// asynchronously to avoid blocking the request path.
type InMemoryEventBus struct {
bufferSize int
ch chan Event
subsMu sync.RWMutex
subscribers []chan Event
stopCh chan struct{}
wg sync.WaitGroup
retryInterval time.Duration
maxRetries int
stats busStats
}
// NewInMemoryEventBus creates a new in-memory event bus with the given buffer size.
func NewInMemoryEventBus(bufferSize int) *InMemoryEventBus {
b := &InMemoryEventBus{
bufferSize: bufferSize,
ch: make(chan Event, bufferSize),
stopCh: make(chan struct{}),
retryInterval: 10 * time.Millisecond,
maxRetries: 3,
}
b.wg.Add(1)
go b.loop()
return b
}
// Publish sends an event to the bus without blocking if the buffer is full.
func (b *InMemoryEventBus) Publish(ctx context.Context, evt Event) {
select {
case b.ch <- evt:
b.stats.published.Add(1)
default:
b.stats.dropped.Add(1)
}
}
// Subscribe returns a channel that receives events published to the bus.
// Each subscriber receives all events.
func (b *InMemoryEventBus) Subscribe() <-chan Event {
sub := make(chan Event, b.bufferSize)
b.subsMu.Lock()
b.subscribers = append(b.subscribers, sub)
b.subsMu.Unlock()
return sub
}
func (b *InMemoryEventBus) loop() {
defer b.wg.Done()
for {
select {
case evt := <-b.ch:
b.dispatch(evt)
case <-b.stopCh:
b.subsMu.Lock()
for _, sub := range b.subscribers {
close(sub)
}
b.subscribers = nil
b.subsMu.Unlock()
return
}
}
}
func (b *InMemoryEventBus) dispatch(evt Event) {
b.subsMu.RLock()
subs := append([]chan Event(nil), b.subscribers...)
b.subsMu.RUnlock()
for _, sub := range subs {
sent := false
for i := 0; i <= b.maxRetries; i++ {
select {
case sub <- evt:
sent = true
default:
time.Sleep(b.retryInterval * time.Duration(i+1))
}
if sent {
break
}
}
}
}
// Stop gracefully stops the event bus and closes all subscriber channels.
func (b *InMemoryEventBus) Stop() {
close(b.stopCh)
b.wg.Wait()
}
// Stats returns the number of published and dropped events.
func (b *InMemoryEventBus) Stats() (published, dropped int) {
return int(b.stats.published.Load()), int(b.stats.dropped.Load())
}
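// Usage sketch (illustrative; ctx stands in for a context.Context):
//
//	bus := NewInMemoryEventBus(1024)
//	sub := bus.Subscribe()
//	bus.Publish(ctx, Event{RequestID: "r1"}) // non-blocking; dropped if the buffer is full
//	evt := <-sub                             // each subscriber receives every event
//	bus.Stop()                               // closes all subscriber channels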
// RedisClient abstracts the subset of Redis commands used by the event bus:
// LPUSH, LRANGE, LLEN, EXPIRE, LTRIM, INCR, GET and SET.
type RedisClient interface {
LPush(ctx context.Context, key string, values ...interface{}) error
LRANGE(ctx context.Context, key string, start, stop int64) ([]string, error)
LLEN(ctx context.Context, key string) (int64, error)
EXPIRE(ctx context.Context, key string, expiration time.Duration) error
LTRIM(ctx context.Context, key string, start, stop int64) error
Incr(ctx context.Context, key string) (int64, error)
Get(ctx context.Context, key string) (string, error)
Set(ctx context.Context, key, value string) error
}
// RedisGoClientAdapter adapts go-redis/v9 Client to the RedisClient interface.
type RedisGoClientAdapter struct {
Client *redis.Client
}
// LRANGE returns the elements of the list stored at key in [start, stop].
func (a *RedisGoClientAdapter) LRANGE(ctx context.Context, key string, start, stop int64) ([]string, error) {
return a.Client.LRange(ctx, key, start, stop).Result()
}
func (a *RedisGoClientAdapter) LLEN(ctx context.Context, key string) (int64, error) {
return a.Client.LLen(ctx, key).Result()
}
func (a *RedisGoClientAdapter) EXPIRE(ctx context.Context, key string, expiration time.Duration) error {
return a.Client.Expire(ctx, key, expiration).Err()
}
func (a *RedisGoClientAdapter) LTRIM(ctx context.Context, key string, start, stop int64) error {
return a.Client.LTrim(ctx, key, start, stop).Err()
}
func (a *RedisGoClientAdapter) LPush(ctx context.Context, key string, values ...interface{}) error {
return a.Client.LPush(ctx, key, values...).Err()
}
func (a *RedisGoClientAdapter) Incr(ctx context.Context, key string) (int64, error) {
return a.Client.Incr(ctx, key).Result()
}
func (a *RedisGoClientAdapter) Get(ctx context.Context, key string) (string, error) {
return a.Client.Get(ctx, key).Result()
}
func (a *RedisGoClientAdapter) Set(ctx context.Context, key, value string) error {
return a.Client.Set(ctx, key, value, 0).Err()
}
// NewRedisEventBusLog creates a Redis event bus that acts as a persistent log (non-destructive, with TTL and optional max length)
func NewRedisEventBusLog(client RedisClient, key string, ttl time.Duration, maxLen int64) *RedisEventBus {
return &RedisEventBus{
client: client,
key: key,
logTTL: ttl,
maxLen: maxLen,
}
}
// RedisEventBus is a Redis-backed EventBus implementation. Events are encoded as
// JSON and pushed to a Redis list. This version is a persistent log: events are never removed by consumers.
type RedisEventBus struct {
client RedisClient
key string
stats busStats
logTTL time.Duration // TTL for the Redis list
maxLen int64 // Max length for the Redis list
}
// NewRedisEventBusPublisher creates a Redis event bus that only publishes events (no background consumer).
func NewRedisEventBusPublisher(client RedisClient, key string) *RedisEventBus {
return &RedisEventBus{
client: client,
key: key,
}
}
// Publish pushes the event JSON to the Redis list.
func (b *RedisEventBus) Publish(ctx context.Context, evt Event) {
// Assign a monotonic LogID
seq, err := b.client.Incr(ctx, b.key+":seq")
if err != nil {
log.Printf("[eventbus] Failed to increment event log sequence: %v", err)
b.stats.dropped.Add(1)
return
}
evt.LogID = seq
data, err := json.Marshal(evt)
if err != nil {
log.Printf("[eventbus] Failed to marshal event: %v", err)
b.stats.dropped.Add(1)
return
}
if err := b.client.LPush(ctx, b.key, data); err != nil {
log.Printf("[eventbus] Failed to publish event to Redis key %s: %v", b.key, err)
b.stats.dropped.Add(1)
return
}
if b.maxLen > 0 {
_ = b.client.LTRIM(ctx, b.key, 0, b.maxLen-1)
}
if b.logTTL > 0 {
_ = b.client.EXPIRE(ctx, b.key, b.logTTL)
}
b.stats.published.Add(1)
}
// ReadEvents returns events in [start, end] (inclusive, like LRANGE). Because
// Publish uses LPUSH, index 0 is the newest event; callers needing chronological
// order must reverse the result.
func (b *RedisEventBus) ReadEvents(ctx context.Context, start, end int64) ([]Event, error) {
items, err := b.client.LRANGE(ctx, b.key, start, end)
if err != nil {
return nil, err
}
var events []Event
for _, item := range items {
var evt Event
if err := json.Unmarshal([]byte(item), &evt); err == nil {
events = append(events, evt)
}
}
return events, nil
}
// EventCount returns the current number of events in the log
func (b *RedisEventBus) EventCount(ctx context.Context) (int64, error) {
return b.client.LLEN(ctx, b.key)
}
// Stop is a no-op for the log-based RedisEventBus (required to satisfy EventBus interface)
func (b *RedisEventBus) Stop() {}
// Subscribe is not supported for the log-based RedisEventBus. It returns a closed channel.
func (b *RedisEventBus) Subscribe() <-chan Event {
ch := make(chan Event)
close(ch)
return ch
}
// Client returns the underlying RedisClient for this RedisEventBus
func (b *RedisEventBus) Client() RedisClient {
return b.client
}
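// Usage sketch (illustrative; rdb stands in for a connected *redis.Client and the
// key name is an assumption):
//
//	bus := NewRedisEventBusLog(&RedisGoClientAdapter{Client: rdb}, "llm-proxy-events", time.Hour, 10000)
//	bus.Publish(ctx, Event{RequestID: "r1"}) // LPUSH plus a monotonic LogID from INCR
//	events, _ := bus.ReadEvents(ctx, 0, -1)  // newest first; reverse for chronological order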
// Note: Test-only RedisClient mocks have been moved into test files to avoid
// mixing testing helpers with production code and to keep coverage meaningful.
package eventbus
import (
"context"
"encoding/json"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
)
// RedisStreamsClient interface for Redis Streams operations.
// This abstraction allows for easy mocking in tests.
type RedisStreamsClient interface {
// XAdd adds an entry to a stream
XAdd(ctx context.Context, args *redis.XAddArgs) (string, error)
// XReadGroup reads entries from a stream using a consumer group
XReadGroup(ctx context.Context, args *redis.XReadGroupArgs) ([]redis.XStream, error)
// XAck acknowledges processed messages
XAck(ctx context.Context, stream, group string, ids ...string) (int64, error)
// XGroupCreateMkStream creates a consumer group (and the stream if needed)
XGroupCreateMkStream(ctx context.Context, stream, group, start string) error
// XPending returns pending entries for a consumer group
XPending(ctx context.Context, stream, group string) (*redis.XPending, error)
// XPendingExt returns detailed pending entries
XPendingExt(ctx context.Context, args *redis.XPendingExtArgs) ([]redis.XPendingExt, error)
// XClaim claims pending messages for a consumer
XClaim(ctx context.Context, args *redis.XClaimArgs) ([]redis.XMessage, error)
// XLen returns the length of a stream
XLen(ctx context.Context, stream string) (int64, error)
// XInfoGroups returns consumer group info for a stream
XInfoGroups(ctx context.Context, stream string) ([]redis.XInfoGroup, error)
}
// RedisStreamsClientAdapter adapts go-redis/v9 Client to the RedisStreamsClient interface.
type RedisStreamsClientAdapter struct {
Client *redis.Client
}
// XAdd adds an entry to a stream
func (a *RedisStreamsClientAdapter) XAdd(ctx context.Context, args *redis.XAddArgs) (string, error) {
return a.Client.XAdd(ctx, args).Result()
}
// XReadGroup reads entries from a stream using a consumer group
func (a *RedisStreamsClientAdapter) XReadGroup(ctx context.Context, args *redis.XReadGroupArgs) ([]redis.XStream, error) {
return a.Client.XReadGroup(ctx, args).Result()
}
// XAck acknowledges processed messages
func (a *RedisStreamsClientAdapter) XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) {
return a.Client.XAck(ctx, stream, group, ids...).Result()
}
// XGroupCreateMkStream creates a consumer group (and the stream if needed)
func (a *RedisStreamsClientAdapter) XGroupCreateMkStream(ctx context.Context, stream, group, start string) error {
return a.Client.XGroupCreateMkStream(ctx, stream, group, start).Err()
}
// XPending returns pending entries for a consumer group
func (a *RedisStreamsClientAdapter) XPending(ctx context.Context, stream, group string) (*redis.XPending, error) {
return a.Client.XPending(ctx, stream, group).Result()
}
// XPendingExt returns detailed pending entries
func (a *RedisStreamsClientAdapter) XPendingExt(ctx context.Context, args *redis.XPendingExtArgs) ([]redis.XPendingExt, error) {
return a.Client.XPendingExt(ctx, args).Result()
}
// XClaim claims pending messages for a consumer
func (a *RedisStreamsClientAdapter) XClaim(ctx context.Context, args *redis.XClaimArgs) ([]redis.XMessage, error) {
return a.Client.XClaim(ctx, args).Result()
}
// XLen returns the length of a stream
func (a *RedisStreamsClientAdapter) XLen(ctx context.Context, stream string) (int64, error) {
return a.Client.XLen(ctx, stream).Result()
}
// XInfoGroups returns consumer group info for a stream
func (a *RedisStreamsClientAdapter) XInfoGroups(ctx context.Context, stream string) ([]redis.XInfoGroup, error) {
return a.Client.XInfoGroups(ctx, stream).Result()
}
// RedisStreamsConfig holds configuration for Redis Streams event bus.
type RedisStreamsConfig struct {
StreamKey string // Redis stream key name
ConsumerGroup string // Consumer group name
ConsumerName string // Unique consumer name within the group
MaxLen int64 // Max stream length (0 = unlimited, uses MAXLEN ~ approximation)
BlockTimeout time.Duration // Block timeout for XREADGROUP (0 = non-blocking)
ClaimMinIdleTime time.Duration // Minimum idle time before claiming pending messages
BatchSize int64 // Number of messages to read at once
}
// DefaultRedisStreamsConfig returns default configuration.
func DefaultRedisStreamsConfig() RedisStreamsConfig {
return RedisStreamsConfig{
StreamKey: "llm-proxy-events",
ConsumerGroup: "llm-proxy-dispatchers",
ConsumerName: "dispatcher-1",
MaxLen: 10000,
BlockTimeout: 5 * time.Second,
ClaimMinIdleTime: 30 * time.Second,
BatchSize: 100,
}
}
// RedisStreamsEventBus implements EventBus using Redis Streams.
// It provides durable, distributed event delivery with consumer groups,
// acknowledgment, and at-least-once delivery semantics.
type RedisStreamsEventBus struct {
client RedisStreamsClient
config RedisStreamsConfig
stats busStats
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
subscribers []chan Event
subsMu sync.RWMutex
groupCreated atomic.Bool
}
// NewRedisStreamsEventBus creates a new Redis Streams event bus.
func NewRedisStreamsEventBus(client RedisStreamsClient, config RedisStreamsConfig) *RedisStreamsEventBus {
return &RedisStreamsEventBus{
client: client,
config: config,
stopCh: make(chan struct{}),
}
}
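// Construction sketch (illustrative; rdb stands in for a connected *redis.Client):
//
//	cfg := DefaultRedisStreamsConfig()
//	cfg.ConsumerName = "dispatcher-2" // must be unique within the consumer group
//	bus := NewRedisStreamsEventBus(&RedisStreamsClientAdapter{Client: rdb}, cfg)
//	sub := bus.Subscribe() // starts a consumer-group read loop with ack and claim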
// EnsureConsumerGroup creates the consumer group if it doesn't exist.
// This should be called before starting to consume messages.
func (b *RedisStreamsEventBus) EnsureConsumerGroup(ctx context.Context) error {
if b.groupCreated.Load() {
return nil
}
// Try to create the group; if it already exists, that's fine
err := b.client.XGroupCreateMkStream(ctx, b.config.StreamKey, b.config.ConsumerGroup, "0")
if err != nil {
// Check if error is because group already exists
if isGroupExistsError(err) {
b.groupCreated.Store(true)
return nil
}
return fmt.Errorf("failed to create consumer group: %w", err)
}
b.groupCreated.Store(true)
return nil
}
// isGroupExistsError checks if the error indicates the group already exists.
func isGroupExistsError(err error) bool {
if err == nil {
return false
}
// Redis returns "BUSYGROUP Consumer Group name already exists" error
return err.Error() == "BUSYGROUP Consumer Group name already exists"
}
// Publish adds an event to the Redis stream using XADD.
func (b *RedisStreamsEventBus) Publish(ctx context.Context, evt Event) {
data, err := json.Marshal(evt)
if err != nil {
log.Printf("[eventbus] Failed to marshal event: %v", err)
b.stats.dropped.Add(1)
return
}
args := &redis.XAddArgs{
Stream: b.config.StreamKey,
Values: map[string]interface{}{
"data": string(data),
},
}
// Apply MaxLen if configured
if b.config.MaxLen > 0 {
args.MaxLen = b.config.MaxLen
args.Approx = true // Use ~ for better performance
}
_, err = b.client.XAdd(ctx, args)
if err != nil {
log.Printf("[eventbus] Failed to publish event to stream %s: %v", b.config.StreamKey, err)
b.stats.dropped.Add(1)
return
}
b.stats.published.Add(1)
}
// Subscribe returns a channel that receives events from the stream.
// This starts a background goroutine that reads from the stream using consumer groups.
func (b *RedisStreamsEventBus) Subscribe() <-chan Event {
ch := make(chan Event, b.config.BatchSize)
b.subsMu.Lock()
b.subscribers = append(b.subscribers, ch)
b.subsMu.Unlock()
b.wg.Add(1)
go b.consumeLoop(ch)
return ch
}
// consumeLoop reads messages from the stream and dispatches to the subscriber channel.
func (b *RedisStreamsEventBus) consumeLoop(ch chan Event) {
defer b.wg.Done()
defer close(ch)
ctx := context.Background()
// Ensure consumer group exists
if err := b.EnsureConsumerGroup(ctx); err != nil {
log.Printf("[eventbus] Failed to ensure consumer group: %v", err)
return
}
// First, process any pending messages (messages we received but didn't acknowledge)
b.processPendingMessages(ctx, ch)
// Then start normal consumption
for {
select {
case <-b.stopCh:
return
default:
}
// Read new messages
streams, err := b.client.XReadGroup(ctx, &redis.XReadGroupArgs{
Group: b.config.ConsumerGroup,
Consumer: b.config.ConsumerName,
Streams: []string{b.config.StreamKey, ">"},
Count: b.config.BatchSize,
Block: b.config.BlockTimeout,
})
if err != nil {
if err == redis.Nil {
// No new messages, check for pending messages to claim
b.claimPendingMessages(ctx, ch)
continue
}
// Check if context was cancelled or we're stopping
select {
case <-b.stopCh:
return
default:
log.Printf("[eventbus] Error reading from stream: %v", err)
time.Sleep(time.Second) // Back off on error
continue
}
}
// Process received messages
for _, stream := range streams {
for _, msg := range stream.Messages {
evt, err := b.parseMessage(msg)
if err != nil {
log.Printf("[eventbus] Failed to parse message %s: %v", msg.ID, err)
// Acknowledge invalid messages so they don't get stuck
_, _ = b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, msg.ID)
continue
}
// Try to send to subscriber
select {
case ch <- evt:
// Message delivered, acknowledge it
_, err := b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, msg.ID)
if err != nil {
log.Printf("[eventbus] Failed to acknowledge message %s: %v", msg.ID, err)
}
case <-b.stopCh:
return
}
}
}
}
}
// processPendingMessages handles messages that were delivered but not acknowledged.
// This is called on startup to handle messages from a previous crash.
func (b *RedisStreamsEventBus) processPendingMessages(ctx context.Context, ch chan Event) {
// Read our pending messages
streams, err := b.client.XReadGroup(ctx, &redis.XReadGroupArgs{
Group: b.config.ConsumerGroup,
Consumer: b.config.ConsumerName,
Streams: []string{b.config.StreamKey, "0"}, // "0" means pending messages
Count: b.config.BatchSize,
})
if err != nil && err != redis.Nil {
log.Printf("[eventbus] Error reading pending messages: %v", err)
return
}
for _, stream := range streams {
for _, msg := range stream.Messages {
evt, err := b.parseMessage(msg)
if err != nil {
log.Printf("[eventbus] Failed to parse pending message %s: %v", msg.ID, err)
// Acknowledge invalid messages
_, _ = b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, msg.ID)
continue
}
select {
case ch <- evt:
_, _ = b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, msg.ID)
case <-b.stopCh:
return
}
}
}
}
// claimPendingMessages claims messages from other consumers that have been idle too long.
// This implements the "at-least-once" delivery guarantee for crashed consumers.
func (b *RedisStreamsEventBus) claimPendingMessages(ctx context.Context, ch chan Event) {
// Get pending entries that have exceeded the idle time
pending, err := b.client.XPendingExt(ctx, &redis.XPendingExtArgs{
Stream: b.config.StreamKey,
Group: b.config.ConsumerGroup,
Start: "-",
End: "+",
Count: b.config.BatchSize,
})
if err != nil || len(pending) == 0 {
return
}
// Find messages that are idle long enough to claim
var toClaim []string
for _, p := range pending {
if p.Idle >= b.config.ClaimMinIdleTime {
toClaim = append(toClaim, p.ID)
}
}
if len(toClaim) == 0 {
return
}
// Claim the messages
messages, err := b.client.XClaim(ctx, &redis.XClaimArgs{
Stream: b.config.StreamKey,
Group: b.config.ConsumerGroup,
Consumer: b.config.ConsumerName,
MinIdle: b.config.ClaimMinIdleTime,
Messages: toClaim,
})
if err != nil {
log.Printf("[eventbus] Error claiming messages: %v", err)
return
}
// Process claimed messages
for _, msg := range messages {
evt, err := b.parseMessage(msg)
if err != nil {
log.Printf("[eventbus] Failed to parse claimed message %s: %v", msg.ID, err)
_, _ = b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, msg.ID)
continue
}
select {
case ch <- evt:
_, _ = b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, msg.ID)
case <-b.stopCh:
return
}
}
}
// parseMessage extracts an Event from a Redis stream message.
func (b *RedisStreamsEventBus) parseMessage(msg redis.XMessage) (Event, error) {
var evt Event
data, ok := msg.Values["data"]
if !ok {
return evt, fmt.Errorf("message missing 'data' field")
}
dataStr, ok := data.(string)
if !ok {
return evt, fmt.Errorf("'data' field is not a string")
}
if err := json.Unmarshal([]byte(dataStr), &evt); err != nil {
return evt, fmt.Errorf("failed to unmarshal event: %w", err)
}
return evt, nil
}
// Stop gracefully stops the event bus and closes all subscriber channels.
func (b *RedisStreamsEventBus) Stop() {
b.stopOnce.Do(func() {
close(b.stopCh)
b.wg.Wait()
})
}
// Stats returns the number of published and dropped events.
func (b *RedisStreamsEventBus) Stats() (published, dropped int) {
return int(b.stats.published.Load()), int(b.stats.dropped.Load())
}
// StreamLength returns the current length of the stream.
func (b *RedisStreamsEventBus) StreamLength(ctx context.Context) (int64, error) {
return b.client.XLen(ctx, b.config.StreamKey)
}
// PendingCount returns the number of pending messages in the consumer group.
func (b *RedisStreamsEventBus) PendingCount(ctx context.Context) (int64, error) {
pending, err := b.client.XPending(ctx, b.config.StreamKey, b.config.ConsumerGroup)
if err != nil {
return 0, err
}
return pending.Count, nil
}
// Client returns the underlying RedisStreamsClient.
func (b *RedisStreamsEventBus) Client() RedisStreamsClient {
return b.client
}
// Acknowledge manually acknowledges a message by ID.
// This is useful when external code handles message processing and acknowledgment.
func (b *RedisStreamsEventBus) Acknowledge(ctx context.Context, messageID string) error {
_, err := b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, messageID)
return err
}
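// Usage sketch (illustrative only): wiring a consumer against this bus. The
// constructor name and config value below are assumptions for the sketch, not
// necessarily the exact API exposed elsewhere in this package.
//
//	bus := NewRedisStreamsEventBus(client, cfg) // hypothetical constructor
//	ch := bus.Subscribe()
//	go func() {
//		for evt := range ch {
//			_ = evt // handle the event; delivery is at-least-once
//		}
//	}()
//	// ...
//	bus.Stop() // closes subscriber channels once consume loops drain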
package eventtransformer
import (
"bytes"
"compress/gzip"
"encoding/base64"
"encoding/json"
"io"
"strings"
"unicode/utf8"
"github.com/andybalholm/brotli"
)
// DecompressAndDecode attempts to base64-decode the value (standard, then URL-safe), decompresses it (gzip or brotli) when Content-Encoding indicates so, and returns the decoded string plus true if decoding succeeded.
func DecompressAndDecode(val string, headers map[string]interface{}) (string, bool) {
// Only log errors, binary skipping, and major state changes
data := []byte(val)
encoding := ""
contentType := ""
for k, v := range headers {
key := strings.ToLower(strings.ReplaceAll(k, "-", "_"))
if key == "content_encoding" {
if arr, ok := v.([]interface{}); ok && len(arr) > 0 {
if s, ok := arr[0].(string); ok {
encoding = strings.ToLower(s)
}
} else if s, ok := v.(string); ok {
encoding = strings.ToLower(s)
}
}
if key == "content_type" {
if arr, ok := v.([]interface{}); ok && len(arr) > 0 {
if s, ok := arr[0].(string); ok {
contentType = strings.ToLower(s)
}
} else if s, ok := v.(string); ok {
contentType = strings.ToLower(s)
}
}
}
// Skip binary content types (audio, image, octet-stream)
if strings.HasPrefix(contentType, "audio/") || strings.HasPrefix(contentType, "image/") || contentType == "application/octet-stream" {
// Skipping decode for binary content-type
return val, false
}
// 1. Try base64 decode (standard)
decoded, err := base64.StdEncoding.DecodeString(string(data))
if err == nil {
data = decoded
// Use tagged switch for encoding
switch encoding {
case "gzip":
zr, err := gzip.NewReader(bytes.NewReader(data))
if err == nil {
decompressed, err := io.ReadAll(zr)
_ = zr.Close()
if err == nil {
data = decompressed
}
}
case "br":
br := brotli.NewReader(bytes.NewReader(data))
decompressed, err := io.ReadAll(br)
if err == nil {
data = decompressed
}
}
if json.Valid(data) {
return string(data), true
}
if utf8.Valid(data) {
return string(data), true
}
}
// 2. Try base64 decode (URL-safe) on the original input; step 1 may have
// overwritten data with a standard decode that failed the validity checks.
decoded, err = base64.URLEncoding.DecodeString(val)
if err == nil {
data = decoded
// Use tagged switch for encoding
switch encoding {
case "gzip":
zr, err := gzip.NewReader(bytes.NewReader(data))
if err == nil {
decompressed, err := io.ReadAll(zr)
_ = zr.Close()
if err == nil {
data = decompressed
}
}
case "br":
br := brotli.NewReader(bytes.NewReader(data))
decompressed, err := io.ReadAll(br)
if err == nil {
data = decompressed
}
}
if json.Valid(data) {
return string(data), true
}
if utf8.Valid(data) {
return string(data), true
}
}
// 3. If base64 fails, try decompressing original data (legacy case)
switch encoding {
case "gzip":
zr, err := gzip.NewReader(bytes.NewReader([]byte(val)))
if err == nil {
decompressed, err := io.ReadAll(zr)
_ = zr.Close()
if err == nil {
if json.Valid(decompressed) {
return string(decompressed), true
}
if utf8.Valid(decompressed) {
return string(decompressed), true
}
}
}
case "br":
br := brotli.NewReader(bytes.NewReader([]byte(val)))
decompressed, err := io.ReadAll(br)
if err == nil {
if json.Valid(decompressed) {
return string(decompressed), true
}
if utf8.Valid(decompressed) {
return string(decompressed), true
}
}
}
// 4. Fallback: check if input is JSON or UTF-8
if json.Valid([]byte(val)) {
return val, true
}
if utf8.Valid([]byte(val)) {
return val, true
}
// Fallback: returning original input, could not decode
return val, false
}
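// Example (illustrative): a gzip-compressed, base64-encoded JSON body decodes
// when the headers advertise the encoding; b64GzipBody and the header values
// are hypothetical inputs, not fixtures from this codebase.
//
//	headers := map[string]interface{}{
//		"Content-Encoding": []interface{}{"gzip"},
//		"Content-Type":     []interface{}{"application/json"},
//	}
//	decoded, ok := DecompressAndDecode(b64GzipBody, headers)
//	// ok == true with decoded holding the JSON text on success; on failure
//	// the original input comes back with ok == false.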
package eventtransformer
import (
"encoding/base64"
"encoding/json"
"os"
"strings"
"unicode/utf8"
"github.com/google/uuid"
)
// IsOpenAIStreaming detects if the response body is a sequence of OpenAI streaming chunks (data: ... lines).
func IsOpenAIStreaming(body string) bool {
lines := strings.Split(body, "\n")
count := 0
for _, line := range lines {
if strings.HasPrefix(line, "data: ") && !strings.HasPrefix(line, "data: [DONE]") {
count++
}
}
return count > 1 // at least two chunks
}
// MergeOpenAIStreamingChunks parses and merges OpenAI streaming chunks into a single response object.
func MergeOpenAIStreamingChunks(body string) (map[string]any, error) {
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
var (
id string
object string
created int
model string
content strings.Builder
finish string
usage Usage
)
lines := strings.Split(body, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
data := strings.TrimPrefix(line, "data: ")
var chunk map[string]any
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
if v, ok := chunk["id"].(string); ok && id == "" {
id = v
}
if v, ok := chunk["object"].(string); ok && object == "" {
object = v
}
if v, ok := chunk["created"].(float64); ok && created == 0 {
created = int(v)
}
if v, ok := chunk["model"].(string); ok && model == "" {
model = v
}
// Merge choices
if choices, ok := chunk["choices"].([]any); ok && len(choices) > 0 {
choice, ok := choices[0].(map[string]any)
if !ok {
continue
}
if delta, ok := choice["delta"].(map[string]any); ok {
if c, ok := delta["content"].(string); ok {
content.WriteString(c)
}
}
if fr, ok := choice["finish_reason"].(string); ok && fr != "" {
finish = fr
}
}
// Merge usage if present (usually only in last chunk)
if u, ok := chunk["usage"].(map[string]any); ok {
if v, ok := u["prompt_tokens"].(float64); ok {
usage.PromptTokens = int(v)
}
if v, ok := u["completion_tokens"].(float64); ok {
usage.CompletionTokens = int(v)
}
if v, ok := u["total_tokens"].(float64); ok {
usage.TotalTokens = int(v)
}
}
}
merged := map[string]any{
"id": id,
"object": object,
"created": created,
"model": model,
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{
"role": "assistant",
"content": content.String(),
},
"finish_reason": finish,
}},
}
if usage.PromptTokens > 0 || usage.CompletionTokens > 0 || usage.TotalTokens > 0 {
merged["usage"] = map[string]any{
"prompt_tokens": usage.PromptTokens,
"completion_tokens": usage.CompletionTokens,
"total_tokens": usage.TotalTokens,
}
}
// Merged OpenAI chunks successfully
return merged, nil
}
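// Example (illustrative): two SSE chunks merge into one chat-completion-shaped
// object whose message content is the concatenation of the deltas.
//
//	body := "data: {\"id\":\"c1\",\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n" +
//		"data: {\"id\":\"c1\",\"choices\":[{\"delta\":{\"content\":\"lo\"},\"finish_reason\":\"stop\"}]}\n" +
//		"data: [DONE]\n"
//	merged, _ := MergeOpenAIStreamingChunks(body)
//	// the single merged choice has content "Hello" and finish_reason "stop"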
// MergeThreadStreamingChunks parses and merges assistant thread.run streaming events.
func MergeThreadStreamingChunks(body string) (map[string]any, error) {
type Usage struct {
PromptTokens int
CompletionTokens int
TotalTokens int
}
var (
id string
assistantID string
threadID string
status string
created int
model string
contentB strings.Builder
usage Usage
finalContent string // for thread.message.completed
)
lines := strings.Split(body, "\n")
foundCompleted := false
for i := 0; i < len(lines); i++ {
line := strings.TrimSpace(lines[i])
if strings.HasPrefix(line, "event: ") {
eventType := strings.TrimPrefix(line, "event: ")
i++
if i >= len(lines) {
break
}
dataLine := strings.TrimSpace(lines[i])
if !strings.HasPrefix(dataLine, "data: ") || dataLine == "data: [DONE]" {
continue
}
var msg map[string]any
if err := json.Unmarshal([]byte(strings.TrimPrefix(dataLine, "data: ")), &msg); err != nil {
continue
}
switch eventType {
case "thread.run.created", "thread.run.queued":
if v, ok := msg["id"].(string); ok && id == "" {
id = v
}
if v, ok := msg["assistant_id"].(string); ok && assistantID == "" {
assistantID = v
}
if v, ok := msg["thread_id"].(string); ok && threadID == "" {
threadID = v
}
if v, ok := msg["status"].(string); ok && status == "" {
status = v
}
if v, ok := msg["created_at"].(float64); ok && created == 0 {
created = int(v)
}
case "thread.message.delta":
if !foundCompleted {
if delta, ok := msg["delta"].(map[string]any); ok {
if segs, ok := delta["content"].([]any); ok {
for _, seg := range segs {
segMap, ok := seg.(map[string]any)
if !ok {
continue
}
if txt, ok := segMap["text"].(map[string]any); ok {
if val, ok := txt["value"].(string); ok {
contentB.WriteString(val)
}
}
}
}
}
}
case "thread.message.completed":
// If this event is found, ignore all previous deltas and use only this content
if contentArr, ok := msg["content"].([]any); ok {
var sb strings.Builder
for _, seg := range contentArr {
segMap, ok := seg.(map[string]any)
if !ok {
continue
}
if segMap["type"] == "text" {
if txt, ok := segMap["text"].(map[string]any); ok {
if val, ok := txt["value"].(string); ok {
sb.WriteString(val)
}
}
}
}
finalContent = sb.String()
foundCompleted = true
}
case "thread.run.step.completed", "thread.run.completed":
if u, ok := msg["usage"].(map[string]any); ok {
if v, ok := u["prompt_tokens"].(float64); ok {
usage.PromptTokens = int(v)
}
if v, ok := u["completion_tokens"].(float64); ok {
usage.CompletionTokens = int(v)
}
if v, ok := u["total_tokens"].(float64); ok {
usage.TotalTokens = int(v)
}
}
}
}
}
// If thread.message.completed was found, use only its content
if foundCompleted {
contentB.Reset()
contentB.WriteString(finalContent)
}
merged := map[string]any{
"id": id,
"object": "thread.run",
"created_at": created,
"assistant_id": assistantID,
"thread_id": threadID,
"status": status,
"model": model,
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{
"role": "assistant",
"content": contentB.String(),
},
"finish_reason": "",
}},
}
if usage.PromptTokens+usage.CompletionTokens+usage.TotalTokens > 0 {
merged["usage"] = map[string]any{
"prompt_tokens": usage.PromptTokens,
"completion_tokens": usage.CompletionTokens,
"total_tokens": usage.TotalTokens,
}
}
// Merged thread.run chunks successfully
return merged, nil
}
// TransformEvent transforms an OpenAI event for logging/analytics.
// It handles OPTIONS skipping, header filtering, decoding, chunk merging, token counting, and snake_case normalization.
func (t *OpenAITransformer) TransformEvent(evt map[string]any) (map[string]any, error) {
// Skip OPTIONS requests
if method, _ := evt["Method"].(string); strings.ToUpper(method) == "OPTIONS" {
return nil, nil
}
// Filter out set-cookie from response_headers
if headers, ok := evt["ResponseHeaders"].(map[string]any); ok {
for k := range headers {
if strings.ToLower(k) == "set-cookie" {
delete(headers, k)
}
}
}
// Set request_id
requestID := ""
if headers, ok := evt["RequestHeaders"].(map[string]any); ok {
for k, v := range headers {
if strings.ToLower(k) == "x-request-id" {
if s, ok := v.(string); ok && s != "" {
requestID = s
break
}
}
}
}
if requestID == "" {
requestID = uuid.NewString()
}
evt["request_id"] = requestID
// Decode request_body to UTF-8 or compact JSON
for _, key := range []string{"RequestBody", "request_body"} {
if v, ok := evt[key]; ok {
switch val := v.(type) {
case string:
if val != "" {
decoded := tryBase64DecodeWithLog(val)
if compact, _, ok := normalizeToCompactJSON(decoded); ok {
evt["request_body"] = compact
// promptTokenSource is set below if request_body is present
} else if isValidUTF8(decoded) {
evt["request_body"] = decoded
} else {
evt["request_body"] = "[binary or undecodable data]"
}
break // exits the switch only; the loop still visits the normalized "request_body" key
}
case []byte:
if len(val) > 0 {
decoded := tryBase64DecodeWithLog(string(val))
evt["request_body"] = decoded
break
}
}
}
}
// Handle ResponseBody
contentType := ""
hdrMap, _ := evt["ResponseHeaders"].(map[string]any)
if hdrMap == nil {
hdrMap = map[string]any{}
}
if len(hdrMap) > 0 {
for k, v := range hdrMap {
if strings.ToLower(strings.ReplaceAll(k, "-", "_")) == "content_type" {
switch arr := v.(type) {
case []any:
if len(arr) > 0 {
if s, ok := arr[0].(string); ok {
contentType = strings.ToLower(s)
}
}
case string:
contentType = strings.ToLower(arr)
}
}
}
}
if respBody, ok := evt["ResponseBody"].(string); ok && respBody != "" {
decoded, okDecoded := DecompressAndDecode(respBody, hdrMap)
if !okDecoded {
decoded = tryBase64DecodeWithLog(respBody)
}
if strings.HasPrefix(contentType, "audio/") || strings.HasPrefix(contentType, "image/") || contentType == "application/octet-stream" {
if os.Getenv("LOG_BINARY_RESPONSES") == "1" {
evt["response_body"] = respBody
} else {
evt["response_body"] = "[binary or undecodable data]"
}
evt["response_body_binary"] = true
} else {
if IsOpenAIStreaming(decoded) {
var merged map[string]any
var err error
if strings.Contains(decoded, "event: thread.run") {
merged, err = MergeThreadStreamingChunks(decoded)
} else {
merged, err = MergeOpenAIStreamingChunks(decoded)
}
if err == nil {
comp, _ := json.Marshal(merged)
evt["response_body"] = string(comp)
if usage, ok := merged["usage"].(map[string]any); ok {
evt["TokenUsage"] = usage
}
} else {
// Merge stream error - using decoded response body
evt["response_body"] = decoded
}
} else {
if compact, _, ok := normalizeToCompactJSON(decoded); ok {
evt["response_body"] = compact
} else if isValidUTF8(decoded) {
evt["response_body"] = decoded
} else {
evt["response_body"] = "[binary or undecodable data]"
}
// Extract usage from non-streaming completion
if resp, ok := evt["response_body"].(string); ok && json.Valid([]byte(resp)) {
var respObj map[string]any
if err := json.Unmarshal([]byte(resp), &respObj); err == nil {
if usage, ok := respObj["usage"].(map[string]any); ok {
evt["TokenUsage"] = usage
delete(respObj, "usage")
b, _ := json.Marshal(respObj)
evt["response_body"] = string(b)
}
}
}
}
}
}
// Token usage fallback
if _, has := evt["TokenUsage"]; !has {
if resp, ok := evt["response_body"].(string); ok && json.Valid([]byte(resp)) {
pt, ct := 0, 0
// Compose prompt token source: messages + instructions if present
var promptTokenSource string
if req, ok := evt["request_body"].(string); ok && req != "" {
var reqObj map[string]any
if err := json.Unmarshal([]byte(req), &reqObj); err == nil {
if msgs, ok := reqObj["messages"]; ok {
b, _ := json.Marshal(msgs)
promptTokenSource = string(b)
}
if instr, ok := reqObj["instructions"].(string); ok && instr != "" {
if promptTokenSource != "" {
promptTokenSource += instr
} else {
promptTokenSource = instr
}
}
}
}
if promptTokenSource != "" {
modelName := ""
if req, ok := evt["request_body"].(string); ok && req != "" {
var reqObj map[string]any
_ = json.Unmarshal([]byte(req), &reqObj)
if m, ok := reqObj["model"].(string); ok {
modelName = m
}
}
t, _ := CountOpenAITokensForModel(promptTokenSource, modelName)
pt = t
}
cnt, _ := extractAssistantReplyContent(resp)
if cnt != "" {
// Try to read model from the parsed response JSON
modelName := ""
var respObj map[string]any
if err := json.Unmarshal([]byte(resp), &respObj); err == nil {
if m, ok := respObj["model"].(string); ok {
modelName = m
}
}
tk, _ := CountOpenAITokensForModel(cnt, modelName)
ct = tk
}
evt["TokenUsage"] = map[string]int{"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": pt + ct}
}
}
// Clean up and normalize
delete(evt, "RequestBody")
delete(evt, "ResponseBody")
delete(evt, "response_body_streamed")
return ToSnakeCaseMap(evt), nil
}
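// Input/output sketch (illustrative): keys arrive in the proxy's internal
// CamelCase form and leave snake_cased, with bodies compacted and a token_usage
// fallback computed from the request messages and assistant reply.
//
//	evt := map[string]any{
//		"Method":       "POST",
//		"Path":         "/v1/chat/completions",
//		"RequestBody":  `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`,
//		"ResponseBody": `{"model":"gpt-4o","choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
//	}
//	out, _ := (&OpenAITransformer{}).TransformEvent(evt)
//	// out contains request_id, request_body, response_body, token_usage, ...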
// -- helper functions --
func tryBase64DecodeWithLog(val string) string {
clean := val
if json.Valid([]byte(clean)) {
return clean
}
if b, err := base64.StdEncoding.DecodeString(clean); err == nil {
return string(b)
}
if b, err := base64.URLEncoding.DecodeString(clean); err == nil {
return string(b)
}
return clean
}
func normalizeToCompactJSON(input string) (string, string, bool) {
var obj any
if err := json.Unmarshal([]byte(input), &obj); err != nil {
return input, "", false
}
if s, ok := obj.(string); ok {
var inner any
if err := json.Unmarshal([]byte(s), &inner); err == nil {
b, _ := json.Marshal(inner)
str := string(b)
if m, ok := inner.(map[string]any); ok {
if msgs, ok := m["messages"]; ok {
mj, _ := json.Marshal(msgs)
return str, string(mj), true
}
}
return str, "", true
}
return s, "", true
}
b, err := json.Marshal(obj)
if err != nil {
return input, "", false
}
str := string(b)
if m, ok := obj.(map[string]any); ok {
if msgs, ok := m["messages"]; ok {
mj, _ := json.Marshal(msgs)
return str, string(mj), true
}
}
return str, "", true
}
func isValidUTF8(s string) bool { return utf8.ValidString(s) }
func extractAssistantReplyContent(resp string) (string, error) {
var obj map[string]any
if err := json.Unmarshal([]byte(resp), &obj); err != nil {
return "", err
}
if ch, ok := obj["choices"].([]any); ok && len(ch) > 0 {
if c0, ok := ch[0].(map[string]any); ok {
if msg, ok := c0["message"].(map[string]any); ok {
if content, ok := msg["content"].(string); ok {
return content, nil
}
}
}
}
return "", nil
}
package eventtransformer
import (
"strings"
"unicode"
)
// ToSnakeCaseMap recursively converts all map keys to snake_case, replacing dashes and handling consecutive uppercase letters
func ToSnakeCaseMap(m map[string]interface{}) map[string]interface{} {
snake := make(map[string]interface{}, len(m))
for k, v := range m {
sk := ToSnakeCase(k)
switch vv := v.(type) {
case map[string]interface{}:
snake[sk] = ToSnakeCaseMap(vv)
case []interface{}:
arr := make([]interface{}, len(vv))
for i, elem := range vv {
if mm, ok := elem.(map[string]interface{}); ok {
arr[i] = ToSnakeCaseMap(mm)
} else {
arr[i] = elem
}
}
snake[sk] = arr
default:
snake[sk] = v
}
}
return snake
}
// ToSnakeCase converts a string from CamelCase, PascalCase, or kebab-case to snake_case
func ToSnakeCase(s string) string {
// Replace dashes with underscores
s = strings.ReplaceAll(s, "-", "_")
var out []rune
var prevLower, prevUnderscore bool
for i, r := range s {
if r == '_' {
out = append(out, r)
prevLower, prevUnderscore = false, true
continue
}
if unicode.IsUpper(r) {
if (i > 0 && prevLower) || (i > 0 && !prevUnderscore && i+1 < len(s) && unicode.IsLower(rune(s[i+1]))) {
out = append(out, '_')
}
out = append(out, unicode.ToLower(r))
prevLower, prevUnderscore = false, false
} else {
out = append(out, r)
prevLower, prevUnderscore = true, false
}
}
return string(out)
}
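// Examples (illustrative):
//
//	ToSnakeCase("RequestID")    // "request_id"
//	ToSnakeCase("X-Request-ID") // "x_request_id"
//	ToSnakeCase("ResponseBody") // "response_body"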
package eventtransformer
import (
"strings"
"github.com/pkoukk/tiktoken-go"
)
// CountOpenAITokens counts tokens using a general-purpose encoding.
// Note: Prefer CountOpenAITokensForModel when the model is known.
func CountOpenAITokens(text string) (int, error) {
enc, err := tiktoken.GetEncoding("cl100k_base")
if err != nil {
return 0, err
}
return len(enc.Encode(text, nil, nil)), nil
}
// CountOpenAITokensForModel selects an encoding based on the provided model name.
// Fallback rules:
// - 4o/omni/o1 family → o200k_base
// - otherwise → cl100k_base
// If EncodingForModel succeeds, it is used directly.
func CountOpenAITokensForModel(text, model string) (int, error) {
if model != "" {
if enc, err := tiktoken.EncodingForModel(model); err == nil {
return len(enc.Encode(text, nil, nil)), nil
}
}
// Heuristic fallback by family
var base string
lower := strings.ToLower(model)
switch {
case strings.Contains(lower, "gpt-4o"), strings.HasPrefix(lower, "o1"), strings.Contains(lower, "omni"):
base = "o200k_base"
default:
base = "cl100k_base"
}
enc, err := tiktoken.GetEncoding(base)
if err != nil {
// Last resort: try cl100k_base
enc, err = tiktoken.GetEncoding("cl100k_base")
if err != nil {
return 0, err
}
}
return len(enc.Encode(text, nil, nil)), nil
}
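// Usage sketch (illustrative; exact counts depend on the tiktoken vocabulary
// resolved for the model):
//
//	n, err := CountOpenAITokensForModel("Hello, world", "gpt-4o")
//	// n is the token count under the model's encoding (o200k_base for the
//	// 4o family); an unknown model falls back to cl100k_base.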
// Package eventtransformer provides event transformation logic for different LLM API providers.
// Each provider (e.g., OpenAI, Anthropic) should have its own transformer implementation.
package eventtransformer
// Transformer is the interface for provider-specific event transformers.
type Transformer interface {
TransformEvent(event map[string]interface{}) (map[string]interface{}, error)
}
// DispatchTransformer returns the appropriate transformer for a given provider.
func DispatchTransformer(provider string) Transformer {
switch provider {
case "openai":
return &OpenAITransformer{}
// case "anthropic":
// return &AnthropicTransformer{}
default:
return nil
}
}
// OpenAITransformer implements Transformer for OpenAI events.
type OpenAITransformer struct{}
package logging
import (
"context"
"os"
"strings"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// NewLogger creates a zap.Logger with the specified level, format, and optional file output.
// level can be debug, info, warn, or error. format can be json or console.
// If filePath is empty, logs are written to stdout.
func NewLogger(level, format, filePath string) (*zap.Logger, error) {
var lvl zapcore.Level
switch strings.ToLower(level) {
case "debug":
lvl = zapcore.DebugLevel
case "info", "":
lvl = zapcore.InfoLevel
case "warn":
lvl = zapcore.WarnLevel
case "error":
lvl = zapcore.ErrorLevel
default:
lvl = zapcore.InfoLevel
}
encCfg := zapcore.EncoderConfig{
TimeKey: "ts",
LevelKey: "level",
NameKey: "logger",
MessageKey: "msg",
CallerKey: "caller",
StacktraceKey: "stacktrace",
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeDuration: zapcore.StringDurationEncoder,
EncodeLevel: zapcore.LowercaseLevelEncoder,
}
var encoder zapcore.Encoder
if strings.ToLower(format) == "console" {
encoder = zapcore.NewConsoleEncoder(encCfg)
} else {
encoder = zapcore.NewJSONEncoder(encCfg)
}
var ws = zapcore.AddSync(os.Stdout)
if filePath != "" {
f, err := os.OpenFile(filePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
return nil, err
}
ws = f
}
core := zapcore.NewCore(encoder, ws, lvl)
return zap.New(core), nil
}
// Context keys for request and correlation IDs
type contextKey string
const (
requestIDKey contextKey = "request_id"
correlationIDKey contextKey = "correlation_id"
)
// Canonical field helpers for structured logging
// RequestFields returns fields for HTTP request logging
func RequestFields(requestID, method, path string, statusCode, durationMs int) []zap.Field {
return []zap.Field{
zap.String("request_id", requestID),
zap.String("method", method),
zap.String("path", path),
zap.Int("status_code", statusCode),
zap.Int("duration_ms", durationMs),
}
}
// CorrelationID returns a field for correlation ID
func CorrelationID(id string) zap.Field {
return zap.String("correlation_id", id)
}
// ProjectID returns a field for project ID
func ProjectID(id string) zap.Field {
return zap.String("project_id", id)
}
// TokenID returns a field for token ID (obfuscated for security)
func TokenID(token string) zap.Field {
return zap.String("token_id", obfuscate.ObfuscateTokenGeneric(token))
}
// ClientIP returns a field for client IP address
func ClientIP(ip string) zap.Field {
return zap.String("client_ip", ip)
}
// Context management for request/correlation IDs
// WithRequestID adds a request ID to the context
func WithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDKey, requestID)
}
// WithCorrelationID adds a correlation ID to the context
func WithCorrelationID(ctx context.Context, correlationID string) context.Context {
return context.WithValue(ctx, correlationIDKey, correlationID)
}
// GetRequestID retrieves the request ID from context
func GetRequestID(ctx context.Context) (string, bool) {
id, ok := ctx.Value(requestIDKey).(string)
return id, ok
}
// GetCorrelationID retrieves the correlation ID from context
func GetCorrelationID(ctx context.Context) (string, bool) {
id, ok := ctx.Value(correlationIDKey).(string)
return id, ok
}
// Logger enhancement helpers
// WithRequestContext adds request ID from context to logger if present
func WithRequestContext(ctx context.Context, logger *zap.Logger) *zap.Logger {
if requestID, ok := GetRequestID(ctx); ok {
return logger.With(zap.String("request_id", requestID))
}
return logger
}
// WithCorrelationContext adds correlation ID from context to logger if present
func WithCorrelationContext(ctx context.Context, logger *zap.Logger) *zap.Logger {
if correlationID, ok := GetCorrelationID(ctx); ok {
return logger.With(zap.String("correlation_id", correlationID))
}
return logger
}
// NewChildLogger creates a child logger with a component field
func NewChildLogger(parent *zap.Logger, component string) *zap.Logger {
return parent.With(zap.String("component", component))
}
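// Usage sketch (illustrative): build a logger, emit a canonical request log
// line, and thread a request ID through the context.
//
//	logger, err := NewLogger("info", "json", "")
//	if err != nil {
//		panic(err)
//	}
//	logger.Info("request completed", RequestFields("req-123", "GET", "/v1/models", 200, 42)...)
//	ctx := WithRequestID(context.Background(), "req-123")
//	reqLog := WithRequestContext(ctx, logger) // adds request_id to every line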
package middleware
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/sofatutor/llm-proxy/internal/eventbus"
"github.com/sofatutor/llm-proxy/internal/logging"
"go.uber.org/zap"
)
// Middleware defines a function that wraps an http.Handler.
type Middleware func(http.Handler) http.Handler
// ObservabilityConfig controls the behavior of the observability middleware.
type ObservabilityConfig struct {
Enabled bool
EventBus eventbus.EventBus
}
// ObservabilityMiddleware captures request/response data and forwards it to an event bus.
type ObservabilityMiddleware struct {
cfg ObservabilityConfig
logger *zap.Logger
}
// NewObservabilityMiddleware creates a new ObservabilityMiddleware instance.
func NewObservabilityMiddleware(cfg ObservabilityConfig, logger *zap.Logger) *ObservabilityMiddleware {
if logger == nil {
logger = zap.NewNop()
}
return &ObservabilityMiddleware{cfg: cfg, logger: logger}
}
// Middleware returns the http middleware function.
func (m *ObservabilityMiddleware) Middleware() Middleware {
if !m.cfg.Enabled || m.cfg.EventBus == nil {
return func(next http.Handler) http.Handler { return next }
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
crw := &captureResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
var reqBody []byte
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
if r.Body != nil {
// Read and buffer the request body
bodyBytes, err := io.ReadAll(r.Body)
if err == nil {
reqBody = bodyBytes
// Restore the body for downstream handlers
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
}
}
next.ServeHTTP(crw, r)
// Resolve request ID from header, then context, then response headers
reqID := r.Header.Get("X-Request-ID")
if reqID == "" {
if v, ok := logging.GetRequestID(r.Context()); ok {
reqID = v
}
}
if reqID == "" {
reqID = crw.Header().Get("X-Request-ID")
}
// Skip publishing cache hits (do not incur provider cost)
if v := strings.ToLower(crw.Header().Get("X-PROXY-CACHE")); v == "hit" || v == "conditional-hit" {
return
}
evt := eventbus.Event{
RequestID: reqID,
Method: r.Method,
Path: r.URL.Path,
Status: crw.statusCode,
Duration: time.Since(start),
ResponseHeaders: cloneHeader(crw.Header()),
ResponseBody: crw.body.Bytes(),
RequestBody: reqBody,
}
go m.cfg.EventBus.Publish(context.Background(), evt)
})
}
}
// captureResponseWriter wraps http.ResponseWriter to capture status and body while supporting streaming.
type captureResponseWriter struct {
http.ResponseWriter
statusCode int
body bytes.Buffer
}
func (w *captureResponseWriter) WriteHeader(code int) {
w.statusCode = code
w.ResponseWriter.WriteHeader(code)
}
func (w *captureResponseWriter) Write(b []byte) (int, error) {
w.body.Write(b)
return w.ResponseWriter.Write(b)
}
func (w *captureResponseWriter) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
func (w *captureResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := w.ResponseWriter.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, fmt.Errorf("hijack not supported")
}
func (w *captureResponseWriter) Push(target string, opts *http.PushOptions) error {
if p, ok := w.ResponseWriter.(http.Pusher); ok {
return p.Push(target, opts)
}
return http.ErrNotSupported
}
func cloneHeader(h http.Header) http.Header {
cloned := make(http.Header, len(h))
for k, v := range h {
vv := make([]string, len(v))
copy(vv, v)
cloned[k] = vv
}
return cloned
}
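// Wiring sketch (illustrative; bus, logger, and mux are assumptions supplied
// by the caller):
//
//	m := NewObservabilityMiddleware(ObservabilityConfig{Enabled: true, EventBus: bus}, logger)
//	handler := m.Middleware()(mux)
//	// every non-cache-hit response is published asynchronously as an eventbus.Event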
package middleware
import (
"net/http"
"strings"
"github.com/google/uuid"
"github.com/sofatutor/llm-proxy/internal/logging"
)
// RequestIDMiddleware handles request and correlation ID context propagation
type RequestIDMiddleware struct{}
// NewRequestIDMiddleware creates a new middleware instance for request ID management
func NewRequestIDMiddleware() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get or generate request ID
requestID := getOrGenerateID(r.Header.Get("X-Request-ID"))
// Get or generate correlation ID
correlationID := getOrGenerateID(r.Header.Get("X-Correlation-ID"))
// Add IDs to context
ctx := logging.WithRequestID(r.Context(), requestID)
ctx = logging.WithCorrelationID(ctx, correlationID)
// Set response headers
w.Header().Set("X-Request-ID", requestID)
w.Header().Set("X-Correlation-ID", correlationID)
// Continue with the request using the enriched context
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// getOrGenerateID returns the provided ID if valid, otherwise generates a new UUID
func getOrGenerateID(existingID string) string {
// Trim whitespace
existingID = strings.TrimSpace(existingID)
// If empty, generate new UUID
if existingID == "" {
return uuid.New().String()
}
// For now, accept any non-empty ID (could add validation later if needed)
return existingID
}
// Package obfuscate centralizes redaction/obfuscation helpers used across the codebase.
package obfuscate
import (
"strings"
)
// ObfuscateTokenGeneric obfuscates arbitrary token-like strings for display/logging.
// Behavior (kept for backward-compat with previous utils implementation):
// - length <= 4 → all asterisks of same length
// - 5..12 → keep first 2 characters, replace the rest with asterisks
// - > 12 → keep first 8 characters, then "...", then last 4 characters
func ObfuscateTokenGeneric(s string) string {
if len(s) <= 4 {
return strings.Repeat("*", len(s))
}
if len(s) <= 12 {
return s[:2] + strings.Repeat("*", len(s)-2)
}
return s[:8] + "..." + s[len(s)-4:]
}
// ObfuscateTokenSimple obfuscates token-like strings with a fixed pattern suitable for UIs.
// Behavior (kept for backward-compat with Admin template helper):
// - length <= 8 → "****"
// - > 8 → first 4 + "****" + last 4
func ObfuscateTokenSimple(s string) string {
if len(s) <= 8 {
return "****"
}
return s[:4] + "****" + s[len(s)-4:]
}
// ObfuscateTokenByPrefix obfuscates tokens that follow a known prefix convention (e.g., "sk-").
// Behavior (kept for backward-compat with token.ObfuscateToken):
//   - With prefix: keep the prefix, then show the first 4 and last 4 characters of the
//     remainder, replacing the middle with '*'. If the remainder is <= 8 characters,
//     the input is returned unmodified.
//   - Without prefix: the input is returned unchanged.
func ObfuscateTokenByPrefix(s string, prefix string) string {
if s == "" {
return s
}
if !strings.HasPrefix(s, prefix) {
// Preserve previous behavior of token.ObfuscateToken: leave non-prefixed strings unchanged
return s
}
rest := s[len(prefix):]
if len(rest) <= 8 {
return s
}
visible := 4
first := rest[:visible]
last := rest[len(rest)-visible:]
middle := strings.Repeat("*", len(rest)-(visible*2))
return prefix + first + middle + last
}
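// Examples (illustrative):
//
//	ObfuscateTokenGeneric("abcd")                        // "****"
//	ObfuscateTokenGeneric("abcdefgh")                    // "ab******"
//	ObfuscateTokenGeneric("sk-1234567890abcdef")         // "sk-12345...cdef"
//	ObfuscateTokenSimple("sk-1234567890abcdef")          // "sk-1****cdef"
//	ObfuscateTokenByPrefix("sk-1234567890abcdef", "sk-") // "sk-1234********cdef"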
package proxy
import (
"net/http"
"strings"
"sync"
"time"
)
type cachedResponse struct {
statusCode int
headers http.Header
body []byte
expiresAt time.Time
vary string // Vary header from upstream response for per-response cache key generation
}
// httpCache is a minimal cache interface used by the proxy cache layer.
// Implementations must be safe for concurrent use.
type httpCache interface {
Get(key string) (cachedResponse, bool)
Set(key string, value cachedResponse)
Purge(key string) bool // Remove exact key
PurgePrefix(prefix string) int // Remove all keys with prefix, return count
}
type inMemoryCache struct {
mu sync.RWMutex
store map[string]cachedResponse
}
func newInMemoryCache() *inMemoryCache {
return &inMemoryCache{store: make(map[string]cachedResponse)}
}
func (c *inMemoryCache) Get(key string) (cachedResponse, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
v, ok := c.store[key]
if !ok {
return cachedResponse{}, false
}
if time.Now().After(v.expiresAt) {
return cachedResponse{}, false
}
return v, true
}
func (c *inMemoryCache) Set(key string, value cachedResponse) {
c.mu.Lock()
c.store[key] = value
c.mu.Unlock()
}
func (c *inMemoryCache) Purge(key string) bool {
c.mu.Lock()
defer c.mu.Unlock()
_, exists := c.store[key]
if exists {
delete(c.store, key)
}
return exists
}
func (c *inMemoryCache) PurgePrefix(prefix string) int {
c.mu.Lock()
defer c.mu.Unlock()
count := 0
for key := range c.store {
if strings.HasPrefix(key, prefix) {
delete(c.store, key)
count++
}
}
return count
}
package proxy
import (
"crypto/sha256"
"encoding/hex"
"net/http"
"net/textproto"
"sort"
"strconv"
"strings"
"time"
)
// generateCacheKey builds a cache key from method, path, query and an optional
// ordered list of header names to include. When headersToInclude is nil or empty,
// no header values are incorporated. When an X-Body-Hash header is present
// (computed by the proxy for body-carrying methods), it is folded into the key;
// for POST/PUT/PATCH an explicit client TTL from Cache-Control (public,
// max-age/s-maxage) is folded in as well, so different TTLs do not collide.
func generateCacheKey(r *http.Request, headersToInclude []string) string {
// Key base: METHOD|PATH|sorted(query)
// Host/scheme are intentionally excluded to keep keys stable across proxy ↔ upstream phases.
b := strings.Builder{}
b.WriteString(r.Method)
b.WriteString("|")
b.WriteString(r.URL.Path)
b.WriteString("|")
// Sorted query to normalize key
keys := make([]string, 0, len(r.URL.Query()))
for k := range r.URL.Query() {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vals := r.URL.Query()[k]
sort.Strings(vals)
b.WriteString(k)
b.WriteString("=")
b.WriteString(strings.Join(vals, ","))
b.WriteString("&")
}
raw := b.String()
sum := sha256.Sum256([]byte(raw))
baseKey := hex.EncodeToString(sum[:])
// Include selected request header values (normalized) deterministically
vb := strings.Builder{}
if len(headersToInclude) > 0 {
for _, hk := range headersToInclude {
if v := r.Header.Get(hk); v != "" {
vb.WriteString(strings.ToLower(hk))
vb.WriteString(":")
vb.WriteString(strings.TrimSpace(v))
vb.WriteString("|")
}
}
}
vraw := baseKey + vb.String()
vsum := sha256.Sum256([]byte(vraw))
varyKey := hex.EncodeToString(vsum[:])
final := strings.Builder{}
final.WriteString(varyKey)
// For methods with body, include X-Body-Hash when present (computed in proxy)
if r.Header.Get("X-Body-Hash") != "" {
final.WriteString("|body=")
final.WriteString(r.Header.Get("X-Body-Hash"))
}
// If client explicitly opts into shared caching with a TTL (public + max-age/s-maxage),
// include the requested TTL in the key so different TTLs do not collide with older entries.
// Only apply this for methods that can carry a body (POST/PUT/PATCH) to avoid splitting
// GET/HEAD cache keys unnecessarily when origin already provides TTLs.
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
cc := parseCacheControl(r.Header.Get("Cache-Control"))
if cc.publicCache && (cc.sMaxAge > 0 || cc.maxAge > 0) {
final.WriteString("|ttl=")
if cc.sMaxAge > 0 {
final.WriteString("smax=")
final.WriteString(strconv.Itoa(cc.sMaxAge))
} else {
final.WriteString("max=")
final.WriteString(strconv.Itoa(cc.maxAge))
}
}
}
return final.String()
}
// CacheKeyFromRequest generates a cache key using a conservative default set
// of request headers, for use before a per-response Vary is known.
func CacheKeyFromRequest(r *http.Request) string {
// Conservative subset until per-response Vary handling is applied by caller
return generateCacheKey(r, []string{"Accept", "Accept-Encoding", "Accept-Language"})
}
// CacheKeyFromRequestWithVary generates a cache key using the specified Vary header
// from the upstream response. This enables per-response driven cache key generation.
//
// Behavior:
// - When vary is empty or "*", this function returns a key that ignores header values
// (treated as "no vary" for key generation purposes).
// - Header names parsed from vary are normalized to canonical MIME header names.
func CacheKeyFromRequestWithVary(r *http.Request, vary string) string {
var headers []string
if vary != "" && vary != "*" {
headers = parseVaryHeader(vary)
sort.Strings(headers) // ensure consistent ordering
}
return generateCacheKey(r, headers)
}
// parseVaryHeader parses a Vary header value and returns the list of header names.
// It handles comma-separated values and normalizes header names.
func parseVaryHeader(vary string) []string {
if vary == "" || vary == "*" {
return nil
}
var headers []string
parts := strings.Split(vary, ",")
for _, part := range parts {
header := strings.TrimSpace(part)
if header != "" {
headers = append(headers, normalizeHeaderName(header))
}
}
return headers
}
// normalizeHeaderName canonicalizes a header name for consistent lookups.
func normalizeHeaderName(name string) string {
return textproto.CanonicalMIMEHeaderKey(name)
}
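// Example (illustrative; req is any *http.Request): the same request yields
// different keys depending on the response's Vary header, so variants do not
// collide.
//
//	k1 := CacheKeyFromRequestWithVary(req, "Accept-Language")
//	k2 := CacheKeyFromRequestWithVary(req, "")
//	// k1 != k2 whenever req carries an Accept-Language header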
func isResponseCacheable(res *http.Response) bool {
if res == nil {
return false
}
cc := parseCacheControl(res.Header.Get("Cache-Control"))
if cc.noStore || cc.privateCache {
return false
}
// Cacheable status codes (basic set)
switch res.StatusCode {
case 200, 203, 301, 308, 404, 410:
// ok
default:
return false
}
// If Authorization was present on request, require explicit shared cache directives
if res.Request != nil {
if res.Request.Header.Get("Authorization") != "" {
if !cc.publicCache && cc.sMaxAge <= 0 {
return false
}
}
}
// Don't cache SSE
if strings.Contains(res.Header.Get("Content-Type"), "text/event-stream") {
return false
}
if res.Header.Get("Vary") == "*" {
return false
}
return true
}
type cacheControl struct {
noStore bool
noCache bool
mustReval bool
maxAge int
sMaxAge int
publicCache bool
privateCache bool
}
func parseCacheControl(v string) cacheControl {
cc := cacheControl{}
parts := strings.Split(v, ",")
for _, p := range parts {
p = strings.TrimSpace(strings.ToLower(p))
switch {
case p == "no-store":
cc.noStore = true
case p == "no-cache":
cc.noCache = true
case p == "must-revalidate":
cc.mustReval = true
case p == "public":
cc.publicCache = true
case p == "private":
cc.privateCache = true
case strings.HasPrefix(p, "s-maxage="):
cc.sMaxAge = atoiSafe(strings.TrimPrefix(p, "s-maxage="))
case strings.HasPrefix(p, "max-age="):
cc.maxAge = atoiSafe(strings.TrimPrefix(p, "max-age="))
}
}
return cc
}
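// Example (illustrative):
//
//	cc := parseCacheControl("public, max-age=60, s-maxage=300")
//	// cc.publicCache == true, cc.maxAge == 60, cc.sMaxAge == 300;
//	// cacheTTLFromHeaders prefers s-maxage, yielding a 300s TTL.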
func cacheTTLFromHeaders(res *http.Response, defaultTTL time.Duration) time.Duration {
cc := parseCacheControl(res.Header.Get("Cache-Control"))
if cc.noStore {
return 0
}
if cc.sMaxAge > 0 {
return time.Duration(cc.sMaxAge) * time.Second
}
if cc.maxAge > 0 {
return time.Duration(cc.maxAge) * time.Second
}
if defaultTTL > 0 && (cc.publicCache || (!cc.privateCache && !cc.noCache)) {
return defaultTTL
}
return 0
}
// requestForcedCacheTTL returns a TTL requested by the client via Cache-Control
// when the client explicitly asks for shared caching (public) and provides a TTL.
// This is primarily used for benchmarking when upstream does not send cache hints.
func requestForcedCacheTTL(req *http.Request) time.Duration {
if req == nil {
return 0
}
cc := parseCacheControl(req.Header.Get("Cache-Control"))
if !cc.publicCache {
return 0
}
if cc.sMaxAge > 0 {
return time.Duration(cc.sMaxAge) * time.Second
}
if cc.maxAge > 0 {
return time.Duration(cc.maxAge) * time.Second
}
return 0
}
func atoiSafe(s string) int {
n := 0
for i := 0; i < len(s); i++ {
ch := s[i]
if ch < '0' || ch > '9' {
break
}
n = n*10 + int(ch-'0')
}
return n
}
func cloneHeadersForCache(h http.Header) http.Header {
// Drop hop-by-hop headers
drop := map[string]struct{}{
"Connection": {},
"Keep-Alive": {},
"Proxy-Authenticate": {},
"Proxy-Authorization": {},
"TE": {},
"Trailers": {},
"Transfer-Encoding": {},
"Upgrade": {},
}
out := http.Header{}
for k, vals := range h {
if _, ok := drop[k]; ok {
continue
}
for _, v := range vals {
out.Add(k, v)
}
}
return out
}
// canServeCachedForRequest decides if a cached response is reusable for the given request.
// In particular, for requests with Authorization, only allow reuse when the cached
// response explicitly allows shared caching (public or s-maxage>0).
func canServeCachedForRequest(r *http.Request, cachedHeaders http.Header) bool {
if r == nil {
return false
}
if r.Header.Get("Authorization") == "" {
return true
}
cc := parseCacheControl(cachedHeaders.Get("Cache-Control"))
if cc.publicCache || cc.sMaxAge > 0 {
return true
}
return false
}
// conditionalRequestMatches returns true if the client's conditional headers
// (If-None-Match or If-Modified-Since) match the cached response headers.
// RFC semantics simplified for strong validators; good enough for proxy cache use.
func conditionalRequestMatches(r *http.Request, cachedHeaders http.Header) bool {
if r == nil {
return false
}
// If-None-Match takes precedence over If-Modified-Since
if inm := strings.TrimSpace(r.Header.Get("If-None-Match")); inm != "" {
// Canonicalize header name and compare values (support multiple via comma)
etag := strings.TrimSpace(cachedHeaders.Get(textproto.CanonicalMIMEHeaderKey("ETag")))
if etag == "" {
return false
}
// Strip surrounding quotes for comparison robustness
ce := strings.Trim(etag, "\"")
for _, part := range strings.Split(inm, ",") {
p := strings.TrimSpace(part)
p = strings.Trim(p, "\"")
if p == "*" || p == ce {
return true
}
}
return false
}
if ims := strings.TrimSpace(r.Header.Get("If-Modified-Since")); ims != "" {
lm := strings.TrimSpace(cachedHeaders.Get(textproto.CanonicalMIMEHeaderKey("Last-Modified")))
if lm == "" {
return false
}
imsTime, err1 := http.ParseTime(ims)
lmTime, err2 := http.ParseTime(lm)
if err1 != nil || err2 != nil {
return false
}
if !lmTime.After(imsTime) {
return true
}
}
return false
}
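// Example (illustrative; req is any *http.Request): a client revalidation with
// a matching ETag can be answered 304 from cache.
//
//	req.Header.Set("If-None-Match", `"abc123"`)
//	cached := http.Header{"ETag": []string{`"abc123"`}}
//	conditionalRequestMatches(req, cached) // true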
// hasClientCacheOptIn returns true if the client request explicitly opts into
// shared caching via Cache-Control (public with max-age or s-maxage > 0).
func hasClientCacheOptIn(r *http.Request) bool {
if r == nil {
return false
}
cc := parseCacheControl(r.Header.Get("Cache-Control"))
if !cc.publicCache {
return false
}
if cc.sMaxAge > 0 {
return true
}
if cc.maxAge > 0 {
return true
}
return false
}
// Note: client conditional inspection is handled by conditionalRequestMatches; removed redundant helper.
// wantsRevalidation returns true if the client requests origin revalidation
// (e.g., Cache-Control: no-cache or max-age=0).
func wantsRevalidation(r *http.Request) bool {
if r == nil {
return false
}
ccVal := strings.ToLower(r.Header.Get("Cache-Control"))
if ccVal == "" {
return false
}
if strings.Contains(ccVal, "no-cache") {
return true
}
if strings.Contains(ccVal, "max-age=0") {
return true
}
return false
}
package proxy
import (
"context"
"encoding/json"
"os"
"strconv"
"time"
"github.com/redis/go-redis/v9"
)
// redisCache implements httpCache using Redis.
// It stores cachedResponse as JSON and uses Redis TTL for expiration.
type redisCache struct {
client *redis.Client
prefix string
scanCount int
}
func newRedisCache(client *redis.Client, keyPrefix string) *redisCache {
if keyPrefix == "" {
keyPrefix = "llmproxy:cache:"
}
// Determine SCAN batch size from env (default 2048)
scan := 2048
if v := os.Getenv("REDIS_SCAN_COUNT"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
scan = n
}
}
return &redisCache{client: client, prefix: keyPrefix, scanCount: scan}
}
type redisCachedResponse struct {
StatusCode int `json:"status_code"`
Headers map[string][]string `json:"headers"`
Body []byte `json:"body"`
Vary string `json:"vary"` // Vary header for per-response cache key generation
}
func (r *redisCache) Get(key string) (cachedResponse, bool) {
ctx := context.Background()
data, err := r.client.Get(ctx, r.prefix+key).Bytes()
if err != nil {
return cachedResponse{}, false
}
var rc redisCachedResponse
if err := json.Unmarshal(data, &rc); err != nil {
return cachedResponse{}, false
}
// Convert map to http.Header lazily in caller; keep simple here
hdr := make(map[string][]string, len(rc.Headers))
for k, v := range rc.Headers {
hdr[k] = v
}
return cachedResponse{
statusCode: rc.StatusCode,
headers: hdr,
body: rc.Body,
vary: rc.Vary, // Include vary field
// Redis TTL enforces the real expiry; set a short synthetic expiresAt so
// callers' freshness checks on the in-memory struct still pass.
expiresAt: time.Now().Add(time.Second),
}, true
}
func (r *redisCache) Set(key string, value cachedResponse) {
ctx := context.Background()
// Serialize
ser := redisCachedResponse{StatusCode: value.statusCode, Headers: value.headers, Body: value.body, Vary: value.vary}
payload, err := json.Marshal(ser)
if err != nil {
return
}
ttl := time.Until(value.expiresAt)
if ttl <= 0 {
return
}
_ = r.client.Set(ctx, r.prefix+key, payload, ttl).Err()
}
// Purge removes a single cache entry by exact key. Returns true if deleted.
func (r *redisCache) Purge(key string) bool {
ctx := context.Background()
res := r.client.Del(ctx, r.prefix+key)
n, _ := res.Result()
return n > 0
}
// PurgePrefix removes all cache entries whose keys start with the given prefix.
// Returns number of deleted keys. Uses SCAN to avoid blocking Redis.
func (r *redisCache) PurgePrefix(prefix string) int {
ctx := context.Background()
fullPrefix := r.prefix + prefix
var cursor uint64
total := 0
for {
keys, next, err := r.client.Scan(ctx, cursor, fullPrefix+"*", int64(r.scanCount)).Result()
if err != nil {
// Abort on scan error to avoid infinite loop; return what we deleted so far
return total
}
cursor = next
if len(keys) > 0 {
delCount, _ := r.client.Del(ctx, keys...).Result()
total += int(delCount)
}
if cursor == 0 {
break
}
}
return total
}
// Package proxy provides the transparent reverse proxy implementation.
package proxy
import (
"context"
"sync"
"time"
"go.uber.org/zap"
)
// CacheStatsStore defines the interface for persisting cache hit counts.
//
// Consistency invariant: Under normal operation, CacheHitCount ≤ RequestCount should
// always hold for any token. Cache hits are only recorded when responses are served
// from cache, which requires a prior upstream request that incremented RequestCount.
//
// If CacheHitCount > RequestCount is observed, it indicates a system issue that should
// be investigated (e.g., request count increment failed while cache hit was recorded).
// The Admin UI uses safeSub to display max(0, RequestCount-CacheHitCount) to handle
// this edge case gracefully.
type CacheStatsStore interface {
// IncrementCacheHitCountBatch increments cache_hit_count for multiple tokens.
// The deltas map has token IDs as keys and increment values as values.
IncrementCacheHitCountBatch(ctx context.Context, deltas map[string]int) error
}
// CacheStatsAggregatorConfig holds configuration for the cache stats aggregator.
type CacheStatsAggregatorConfig struct {
BufferSize int // Size of the buffered channel (default: 1000)
FlushInterval time.Duration // How often to flush stats to DB (default: 5s)
BatchSize int // Max events before flush (default: 100)
}
// DefaultCacheStatsAggregatorConfig returns the default configuration.
func DefaultCacheStatsAggregatorConfig() CacheStatsAggregatorConfig {
return CacheStatsAggregatorConfig{
BufferSize: 1000,
FlushInterval: 5 * time.Second,
BatchSize: 100,
}
}
// CacheStatsAggregator aggregates cache hit events and periodically flushes them to the database.
// It uses a buffered channel for non-blocking enqueue and drops events when the buffer is full.
type CacheStatsAggregator struct {
config CacheStatsAggregatorConfig
store CacheStatsStore
logger *zap.Logger
eventsCh chan string // channel of token IDs
stopCh chan struct{}
doneCh chan struct{}
mu sync.RWMutex
stopped bool
}
// NewCacheStatsAggregator creates a new aggregator with the given configuration.
func NewCacheStatsAggregator(config CacheStatsAggregatorConfig, store CacheStatsStore, logger *zap.Logger) *CacheStatsAggregator {
if config.BufferSize <= 0 {
config.BufferSize = 1000
}
if config.FlushInterval <= 0 {
config.FlushInterval = 5 * time.Second
}
if config.BatchSize <= 0 {
config.BatchSize = 100
}
if logger == nil {
logger = zap.NewNop()
}
return &CacheStatsAggregator{
config: config,
store: store,
logger: logger,
eventsCh: make(chan string, config.BufferSize),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
}
// Start begins the background aggregation worker.
func (a *CacheStatsAggregator) Start() {
go a.run()
}
// Stop gracefully shuts down the aggregator, flushing any pending stats.
func (a *CacheStatsAggregator) Stop(ctx context.Context) error {
a.mu.Lock()
if a.stopped {
a.mu.Unlock()
return nil
}
a.stopped = true
a.mu.Unlock()
close(a.stopCh)
select {
case <-a.doneCh:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// RecordCacheHit enqueues a cache hit event for the given token.
// This is non-blocking; if the buffer is full, the event is dropped.
func (a *CacheStatsAggregator) RecordCacheHit(tokenID string) {
if tokenID == "" {
return
}
a.mu.RLock()
stopped := a.stopped
a.mu.RUnlock()
if stopped {
return
}
select {
case a.eventsCh <- tokenID:
// Event enqueued successfully
default:
// Buffer full, drop the event
a.logger.Debug("cache stats buffer full, dropping event",
zap.String("token_id", tokenID))
}
}
// run is the main loop of the aggregator worker.
func (a *CacheStatsAggregator) run() {
defer close(a.doneCh)
ticker := time.NewTicker(a.config.FlushInterval)
defer ticker.Stop()
deltas := make(map[string]int)
eventCount := 0
flush := func() {
if len(deltas) == 0 {
return
}
// Copy and reset
toFlush := deltas
deltas = make(map[string]int)
flushedCount := eventCount
eventCount = 0
// Flush with a short timeout
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := a.store.IncrementCacheHitCountBatch(ctx, toFlush); err != nil {
a.logger.Warn("failed to flush cache hit stats",
zap.Error(err),
zap.Int("event_count", flushedCount),
zap.Int("token_count", len(toFlush)))
// On error, we drop the stats (lossy-tolerant as per design)
} else {
a.logger.Debug("flushed cache hit stats",
zap.Int("event_count", flushedCount),
zap.Int("token_count", len(toFlush)))
}
}
for {
select {
case <-a.stopCh:
// Drain remaining events
for {
select {
case tokenID := <-a.eventsCh:
deltas[tokenID]++
eventCount++
default:
// No more events
flush()
return
}
}
case tokenID := <-a.eventsCh:
deltas[tokenID]++
eventCount++
if eventCount >= a.config.BatchSize {
flush()
}
case <-ticker.C:
flush()
}
}
}
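// Lifecycle sketch (illustrative; store is any CacheStatsStore implementation):
//
//	agg := NewCacheStatsAggregator(DefaultCacheStatsAggregatorConfig(), store, logger)
//	agg.Start()
//	agg.RecordCacheHit("token-123") // non-blocking; dropped if the buffer is full
//	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
//	defer cancel()
//	_ = agg.Stop(ctx) // drains the buffer and flushes pending deltas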
package proxy
import (
"net/http"
"sync"
"time"
)
// Define a custom context key type for circuit breaker cooldown override
type cbContextKey string
const cbCooldownOverrideKey cbContextKey = "circuitbreaker_cooldown_override"
// CircuitBreakerMiddleware returns a middleware that opens the circuit after N consecutive failures.
// While open, it returns 503 immediately. After a cooldown, it closes and allows requests again.
func CircuitBreakerMiddleware(failureThreshold int, cooldown time.Duration, isTransient func(status int) bool) Middleware {
cb := &circuitBreaker{
failureThreshold: failureThreshold,
cooldown: cooldown,
isTransient: isTransient,
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cb.mu.Lock()
if cb.open {
// Allow test override of cooldown via context
if override, ok := r.Context().Value(cbCooldownOverrideKey).(time.Duration); ok {
cb.cooldown = override
}
if time.Since(cb.openedAt) < cb.cooldown {
cb.mu.Unlock()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte("{\"error\":\"Upstream unavailable (circuit breaker open)\"}")) // Ignore error: nothing we can do if write fails
return
}
// Cooldown expired, close circuit
cb.open = false
cb.failureCount = 0
}
cb.mu.Unlock()
rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK}
next.ServeHTTP(rec, r)
if cb.isTransient(rec.statusCode) {
cb.mu.Lock()
cb.failureCount++
if cb.failureCount >= cb.failureThreshold {
cb.open = true
cb.openedAt = time.Now()
}
cb.mu.Unlock()
} else {
cb.mu.Lock()
cb.failureCount = 0
cb.mu.Unlock()
}
})
}
}
type circuitBreaker struct {
mu sync.Mutex
open bool
failureThreshold int
cooldown time.Duration
isTransient func(status int) bool
openedAt time.Time
failureCount int
}
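// Wiring sketch (illustrative): open after 5 consecutive transient failures,
// hold open for 30s, and treat upstream gateway statuses as transient.
//
//	cb := CircuitBreakerMiddleware(5, 30*time.Second, func(status int) bool {
//		return status == 502 || status == 503 || status == 504
//	})
//	handler := cb(upstreamHandler) // upstreamHandler is an assumption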
// Package proxy provides the transparent proxy functionality for the LLM API.
package proxy
import (
"fmt"
"os"
"time"
"gopkg.in/yaml.v3"
)
// APIConfig represents the top-level configuration for API proxying
type APIConfig struct {
// APIs contains configurations for different API providers
APIs map[string]*APIProviderConfig `yaml:"apis"`
// DefaultAPI is the default API provider to use if not specified
DefaultAPI string `yaml:"default_api"`
}
// APIProviderConfig represents the configuration for a specific API provider
type APIProviderConfig struct {
// BaseURL is the target API base URL
BaseURL string `yaml:"base_url"`
// AllowedEndpoints is a list of endpoint prefixes that are allowed to be accessed
AllowedEndpoints []string `yaml:"allowed_endpoints"`
// AllowedMethods is a list of HTTP methods that are allowed
AllowedMethods []string `yaml:"allowed_methods"`
// Timeouts for various operations
Timeouts TimeoutConfig `yaml:"timeouts"`
// Connection settings
Connection ConnectionConfig `yaml:"connection"`
// ParamWhitelist is a map of parameter names to allowed values
ParamWhitelist map[string][]string `yaml:"param_whitelist"`
AllowedOrigins []string `yaml:"allowed_origins"`
RequiredHeaders []string `yaml:"required_headers"`
}
// TimeoutConfig contains timeout settings for the proxy
type TimeoutConfig struct {
// Request is the overall request timeout
Request time.Duration `yaml:"request"`
// ResponseHeader is the timeout for receiving response headers
ResponseHeader time.Duration `yaml:"response_header"`
// IdleConnection is the timeout for idle connections
IdleConnection time.Duration `yaml:"idle_connection"`
// FlushInterval controls how often to flush streaming responses
FlushInterval time.Duration `yaml:"flush_interval"`
}
// ConnectionConfig contains connection settings for the proxy
type ConnectionConfig struct {
// MaxIdleConns is the maximum number of idle connections
MaxIdleConns int `yaml:"max_idle_conns"`
// MaxIdleConnsPerHost is the maximum number of idle connections per host
MaxIdleConnsPerHost int `yaml:"max_idle_conns_per_host"`
}
// LoadAPIConfigFromFile loads API configuration from a YAML file
func LoadAPIConfigFromFile(filePath string) (*APIConfig, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
var config APIConfig
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
// Validate the configuration
if err := validateAPIConfig(&config); err != nil {
return nil, err
}
return &config, nil
}
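// A minimal sketch of the expected YAML shape (the provider name, URL, and
// endpoint list are illustrative, not shipped defaults):
//
//	default_api: openai
//	apis:
//	  openai:
//	    base_url: https://api.openai.com
//	    allowed_endpoints:
//	      - /v1/chat/completions
//	    allowed_methods:
//	      - POST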
// validateAPIConfig checks that the API configuration is valid
func validateAPIConfig(config *APIConfig) error {
if len(config.APIs) == 0 {
return fmt.Errorf("no API providers configured")
}
if config.DefaultAPI != "" {
if _, exists := config.APIs[config.DefaultAPI]; !exists {
return fmt.Errorf("default API '%s' not found in configured APIs", config.DefaultAPI)
}
}
for name, api := range config.APIs {
if api.BaseURL == "" {
return fmt.Errorf("API '%s' has empty base_url", name)
}
if len(api.AllowedEndpoints) == 0 {
return fmt.Errorf("API '%s' has no allowed_endpoints", name)
}
if len(api.AllowedMethods) == 0 {
return fmt.Errorf("API '%s' has no allowed_methods", name)
}
}
return nil
}
// GetProxyConfigForAPI returns a ProxyConfig for the specified API provider
func (c *APIConfig) GetProxyConfigForAPI(apiName string) (*ProxyConfig, error) {
// If no API name specified, use the default
if apiName == "" {
apiName = c.DefaultAPI
}
// Find the API configuration
apiConfig, exists := c.APIs[apiName]
if !exists {
return nil, fmt.Errorf("API provider '%s' not found in configuration", apiName)
}
// Create the proxy configuration
proxyConfig := ProxyConfig{
TargetBaseURL: apiConfig.BaseURL,
AllowedEndpoints: apiConfig.AllowedEndpoints,
AllowedMethods: apiConfig.AllowedMethods,
RequestTimeout: apiConfig.Timeouts.Request,
ResponseHeaderTimeout: apiConfig.Timeouts.ResponseHeader,
FlushInterval: apiConfig.Timeouts.FlushInterval,
IdleConnTimeout: apiConfig.Timeouts.IdleConnection,
MaxIdleConns: apiConfig.Connection.MaxIdleConns,
MaxIdleConnsPerHost: apiConfig.Connection.MaxIdleConnsPerHost,
ParamWhitelist: apiConfig.ParamWhitelist,
AllowedOrigins: apiConfig.AllowedOrigins,
RequiredHeaders: apiConfig.RequiredHeaders,
}
return &proxyConfig, nil
}
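// Usage sketch (the file path is hypothetical):
//
//	cfg, err := LoadAPIConfigFromFile("api_providers.yaml")
//	if err != nil {
//		// handle error
//	}
//	pc, err := cfg.GetProxyConfigForAPI("") // empty name falls back to DefaultAPI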
package proxy
import (
"context"
"errors"
"net/http"
"time"
"github.com/sofatutor/llm-proxy/internal/audit"
)
// TokenValidator defines the interface for token validation
type TokenValidator interface {
// ValidateToken validates a token and returns the associated project ID
ValidateToken(ctx context.Context, token string) (string, error)
// ValidateTokenWithTracking validates a token, increments its usage, and returns the project ID
ValidateTokenWithTracking(ctx context.Context, token string) (string, error)
}
// ProjectStore defines the interface for retrieving and managing project information
// (extended for management API)
type ProjectStore interface {
// GetAPIKeyForProject retrieves the API key for a project
GetAPIKeyForProject(ctx context.Context, projectID string) (string, error)
// GetProjectActive checks if a project is active
GetProjectActive(ctx context.Context, projectID string) (bool, error)
// Management API CRUD
ListProjects(ctx context.Context) ([]Project, error)
CreateProject(ctx context.Context, project Project) error
GetProjectByID(ctx context.Context, projectID string) (Project, error)
UpdateProject(ctx context.Context, project Project) error
DeleteProject(ctx context.Context, projectID string) error
}
// ProjectActiveChecker defines the interface for checking project active status
type ProjectActiveChecker interface {
// GetProjectActive checks if a project is active
GetProjectActive(ctx context.Context, projectID string) (bool, error)
}
// AuditLogger defines the interface for audit event logging
type AuditLogger interface {
// Log records an audit event
Log(event *audit.Event) error
}
// Proxy defines the interface for a transparent HTTP proxy
type Proxy interface {
// Handler returns an http.Handler for the proxy
Handler() http.Handler
// Shutdown gracefully shuts down the proxy
Shutdown(ctx context.Context) error
}
// ProxyConfig contains configuration for the proxy
type ProxyConfig struct {
// TargetBaseURL is the base URL of the API to proxy to
TargetBaseURL string
// AllowedEndpoints is a whitelist of endpoints that can be accessed
AllowedEndpoints []string
// AllowedMethods is a whitelist of HTTP methods that can be used
AllowedMethods []string
// RequestTimeout is the maximum duration for a complete request
RequestTimeout time.Duration
// ResponseHeaderTimeout is the time to wait for response headers
ResponseHeaderTimeout time.Duration
// FlushInterval is how often to flush streaming responses
FlushInterval time.Duration
// MaxIdleConns is the maximum number of idle connections
MaxIdleConns int
// MaxIdleConnsPerHost is the maximum number of idle connections per host
MaxIdleConnsPerHost int
// IdleConnTimeout is how long to keep idle connections alive
IdleConnTimeout time.Duration
// LogLevel controls the verbosity of logging
LogLevel string
// LogFormat controls the log output format (json or console)
LogFormat string
// LogFile specifies a file path for logs (stdout if empty)
LogFile string
// SetXForwardedFor determines whether to set the X-Forwarded-For header
SetXForwardedFor bool
// ParamWhitelist is a map of parameter names to allowed values
ParamWhitelist map[string][]string
// AllowedOrigins is a list of allowed CORS origins for this provider
AllowedOrigins []string
// RequiredHeaders is a list of required request headers (case-insensitive)
RequiredHeaders []string
// Project active guard configuration
EnforceProjectActive bool // Whether to enforce project active status
// --- HTTP cache (global, opt-in; set programmatically, not via YAML) ---
// HTTPCacheEnabled toggles the proxy cache for GET/HEAD based on HTTP semantics
HTTPCacheEnabled bool
// HTTPCacheDefaultTTL is used only when upstream allows caching but omits explicit TTL
HTTPCacheDefaultTTL time.Duration
// HTTPCacheMaxObjectBytes is a guardrail for maximum cacheable response size
HTTPCacheMaxObjectBytes int64
// RedisCacheURL enables Redis-backed cache when non-empty (e.g., redis://localhost:6379/0)
RedisCacheURL string
// RedisCacheKeyPrefix allows namespacing cache keys (default: llmproxy:cache:)
RedisCacheKeyPrefix string
}
// Validate checks that the ProxyConfig is valid and returns an error if not.
func (c *ProxyConfig) Validate() error {
if c.TargetBaseURL == "" {
return errors.New("TargetBaseURL must not be empty")
}
if len(c.AllowedMethods) == 0 {
return errors.New("AllowedMethods must not be empty")
}
if len(c.AllowedEndpoints) == 0 {
return errors.New("AllowedEndpoints must not be empty")
}
return nil
}
// ErrorResponse is the standard format for error responses
type ErrorResponse struct {
Error string `json:"error"`
Description string `json:"description,omitempty"`
Code string `json:"code,omitempty"`
}
// Middleware defines a function that wraps an http.Handler
type Middleware func(http.Handler) http.Handler
// Chain applies a series of middleware to a handler. Middleware are applied
// in reverse order, so the first middleware in the list becomes the outermost
// wrapper and runs first on each request.
func Chain(h http.Handler, middleware ...Middleware) http.Handler {
for i := len(middleware) - 1; i >= 0; i-- {
h = middleware[i](h)
}
return h
}
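// Example of the wrapping order (names are illustrative):
//
//	h := Chain(base, mwA, mwB) // equivalent to mwA(mwB(base)); mwA runs first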
// contextKey is a type for context keys
type contextKey string
// version is the current version of the proxy
const version = "0.1.0"
const (
// ctxKeyRequestID is the context key for the request ID
ctxKeyRequestID contextKey = "request_id"
// ctxKeyProjectID is the context key for the project ID
ctxKeyProjectID contextKey = "project_id"
// ctxKeyTokenID is the context key for the token ID (used for cache stats)
ctxKeyTokenID contextKey = "token_id"
// ctxKeyLogger is the context key for a request-scoped logger
ctxKeyLogger contextKey = "logger"
// ctxKeyOriginalPath stores the original request path before proxy rewriting
ctxKeyOriginalPath contextKey = "original_path"
// ctxKeyValidationError carries token validation error (if any)
ctxKeyValidationError contextKey = "validation_error"
// Timing keys for observability
ctxKeyProxyReceivedAt contextKey = "proxy_received_at"
ctxKeyProxySentBackendAt contextKey = "proxy_sent_backend_at"
ctxKeyProxyFirstRespAt contextKey = "proxy_first_resp_at"
ctxKeyProxyFinalRespAt contextKey = "proxy_final_resp_at"
// ctxKeyRequestStart marks the time when a handler started processing
ctxKeyRequestStart contextKey = "request_start"
)
// Project represents a project for the management API and proxy
// (copied from database/models.go)
type Project struct {
ID string `json:"id"`
Name string `json:"name"`
OpenAIAPIKey string `json:"openai_api_key"`
IsActive bool `json:"is_active"`
DeactivatedAt *time.Time `json:"deactivated_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
package proxy
import (
"context"
"encoding/json"
"net/http"
"strings"
"github.com/sofatutor/llm-proxy/internal/audit"
"github.com/sofatutor/llm-proxy/internal/logging"
"go.uber.org/zap"
)
// shouldAllowProject determines whether the request should proceed based on project active status.
// It returns (allowed, statusCode, errorResponse). When allowed is true, statusCode and errorResponse are ignored.
// It also emits audit events for denied or error cases.
func shouldAllowProject(ctx context.Context, enforceActive bool, checker ProjectActiveChecker, projectID string, auditLogger AuditLogger, r *http.Request) (bool, int, ErrorResponse) {
if !enforceActive {
return true, 0, ErrorResponse{}
}
// Get request metadata for audit events
requestID, _ := logging.GetRequestID(ctx)
clientIP := getClientIP(r)
userAgent := r.UserAgent()
isActive, err := checker.GetProjectActive(ctx, projectID)
if err != nil {
if logger := getLoggerFromContext(ctx); logger != nil {
logger.Error("Failed to check project active status",
zap.String("project_id", projectID),
zap.Error(err))
}
// Emit audit event for service unavailable
if auditLogger != nil {
auditEvent := audit.NewEvent(audit.ActionProxyRequest, audit.ActorSystem, audit.ResultError).
WithProjectID(projectID).
WithRequestID(requestID).
WithClientIP(clientIP).
WithUserAgent(userAgent).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithReason("service_unavailable").
WithError(err)
_ = auditLogger.Log(auditEvent)
}
return false, http.StatusServiceUnavailable, ErrorResponse{Error: "Service temporarily unavailable", Code: "service_unavailable"}
}
if !isActive {
// Emit audit event for project inactive denial
if auditLogger != nil {
auditEvent := audit.NewEvent(audit.ActionProxyRequest, audit.ActorSystem, audit.ResultDenied).
WithProjectID(projectID).
WithRequestID(requestID).
WithClientIP(clientIP).
WithUserAgent(userAgent).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithReason("project_inactive")
_ = auditLogger.Log(auditEvent)
}
return false, http.StatusForbidden, ErrorResponse{Error: "Project is inactive", Code: "project_inactive"}
}
return true, 0, ErrorResponse{}
}
// ProjectActiveGuardMiddleware creates middleware that enforces project active status
// If enforceActive is false, the middleware passes through all requests without checking
// If enforceActive is true, inactive projects receive a 403 Forbidden response
func ProjectActiveGuardMiddleware(enforceActive bool, checker ProjectActiveChecker, auditLogger AuditLogger) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get project ID from context (should be set by token validation middleware)
projectIDValue := r.Context().Value(ctxKeyProjectID)
if projectIDValue == nil {
writeErrorResponse(w, http.StatusInternalServerError, ErrorResponse{
Error: "Internal server error",
Code: "internal_error",
Description: "missing project ID in request context",
})
return
}
projectID, ok := projectIDValue.(string)
if !ok || projectID == "" {
writeErrorResponse(w, http.StatusInternalServerError, ErrorResponse{
Error: "Internal server error",
Code: "internal_error",
Description: "invalid project ID in request context",
})
return
}
if allowed, status, er := shouldAllowProject(r.Context(), enforceActive, checker, projectID, auditLogger, r); !allowed {
writeErrorResponse(w, status, er)
return
}
// Project is active, continue with request
next.ServeHTTP(w, r)
})
}
}
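// Usage sketch (store and auditLogger are caller-provided ProjectActiveChecker
// and AuditLogger implementations). Note the guard expects the project ID to
// already be in the request context, set by the token-validation step.
//
//	guard := ProjectActiveGuardMiddleware(true, store, auditLogger)
//	handler := Chain(proxyHandler, guard)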
// writeErrorResponse writes a JSON error response
func writeErrorResponse(w http.ResponseWriter, statusCode int, errorResp ErrorResponse) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
if err := json.NewEncoder(w).Encode(errorResp); err != nil {
// No request context is available here, so a scoped logger cannot be used.
// The error is intentionally ignored; a global logger could be hooked in if
// visibility is ever needed.
_ = err
}
}
// getLoggerFromContext tries to get a logger from context
// This is a helper function that works with the existing logging context
func getLoggerFromContext(ctx context.Context) *zap.Logger {
// Try to get logger from context if available
// If not available, return nil and let caller handle gracefully
if loggerValue := ctx.Value(ctxKeyLogger); loggerValue != nil {
if logger, ok := loggerValue.(*zap.Logger); ok {
return logger
}
}
return nil
}
// getClientIP extracts the client IP address from the request
// It checks X-Forwarded-For, X-Real-IP headers and falls back to RemoteAddr
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header first (comma-separated list)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP from the list
if idx := strings.Index(xff, ","); idx >= 0 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// Fallback to RemoteAddr (remove port if present)
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx >= 0 {
return r.RemoteAddr[:idx]
}
return r.RemoteAddr
}
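// Examples of the precedence above:
//
//	X-Forwarded-For: "203.0.113.7, 10.0.0.1"  -> "203.0.113.7"
//	X-Real-IP: "198.51.100.2"                 -> "198.51.100.2" (no XFF present)
//	RemoteAddr: "192.0.2.1:54321"             -> "192.0.2.1"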
package proxy
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"path"
"strings"
"sync"
"time"
"crypto/sha256"
"encoding/hex"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/sofatutor/llm-proxy/internal/logging"
"github.com/sofatutor/llm-proxy/internal/middleware"
"github.com/sofatutor/llm-proxy/internal/token"
"go.uber.org/zap"
)
// TransparentProxy implements the Proxy interface for transparent proxying
type TransparentProxy struct {
config ProxyConfig
tokenValidator TokenValidator
projectStore ProjectStore
logger *zap.Logger
auditLogger AuditLogger
metrics *ProxyMetrics
proxy *httputil.ReverseProxy
httpServer *http.Server
shuttingDown bool
mu sync.RWMutex
allowedMethodsHeader string // cached comma-separated allowed methods
obsMiddleware *middleware.ObservabilityMiddleware
cache httpCache
cacheStatsAggregator *CacheStatsAggregator
}
// ProxyMetrics tracks proxy usage statistics
type ProxyMetrics struct {
RequestCount int64
ErrorCount int64
TotalResponseTime time.Duration
// Cache metrics (provider-agnostic counters)
CacheHits int64 // Cache hits (responses served from cache)
CacheMisses int64 // Cache misses (responses fetched from upstream)
CacheBypass int64 // Cache bypassed (e.g., due to authorization)
CacheStores int64 // Cache stores (responses stored in cache)
mu sync.Mutex
}
// CacheMetricType represents the kind of cache metric to increment.
type CacheMetricType int
const (
CacheMetricHit CacheMetricType = iota
CacheMetricMiss
CacheMetricBypass
CacheMetricStore
)
// Metrics returns a pointer to the live proxy metrics. The returned value
// shares state with the proxy, so concurrent readers must synchronize via the
// embedded mutex.
func (p *TransparentProxy) Metrics() *ProxyMetrics {
p.metrics.mu.Lock()
defer p.metrics.mu.Unlock()
return p.metrics
}
// SetMetrics overwrites the current metrics (primarily for testing).
func (p *TransparentProxy) SetMetrics(m *ProxyMetrics) {
p.mu.Lock()
defer p.mu.Unlock()
p.metrics = m
}
// Cache returns the HTTP cache instance for management operations.
// Returns nil if caching is disabled.
func (p *TransparentProxy) Cache() httpCache {
p.mu.RLock()
defer p.mu.RUnlock()
return p.cache
}
// SetCacheStatsAggregator sets the cache stats aggregator for per-token cache hit tracking.
func (p *TransparentProxy) SetCacheStatsAggregator(agg *CacheStatsAggregator) {
p.mu.Lock()
defer p.mu.Unlock()
p.cacheStatsAggregator = agg
}
// isVaryCompatible reports whether a cached response with a given Vary header
// is valid for the current request and lookup key.
func isVaryCompatible(r *http.Request, cr cachedResponse, lookupKey string) bool {
if cr.vary == "" || cr.vary == "*" {
return true
}
varyKey := CacheKeyFromRequestWithVary(r, cr.vary)
return varyKey == lookupKey
}
// storageKeyForResponse returns the cache storage key to use for a response,
// based on the upstream Vary header. Falls back to the lookup key when Vary is empty or '*'.
func storageKeyForResponse(r *http.Request, varyHeader string, lookupKey string) string {
if varyHeader != "" && varyHeader != "*" {
return CacheKeyFromRequestWithVary(r, varyHeader)
}
return lookupKey
}
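// Example: for an upstream response with "Vary: Accept", the storage key is
// derived from the request plus the headers named by Vary (via
// CacheKeyFromRequestWithVary), so requests differing only in Accept occupy
// separate cache entries; with no Vary header, or "Vary: *", the plain lookup
// key is reused.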
// incrementCacheMetric safely increments the specified cache metric counter.
func (p *TransparentProxy) incrementCacheMetric(metric CacheMetricType) {
p.metrics.mu.Lock()
defer p.metrics.mu.Unlock()
switch metric {
case CacheMetricHit:
p.metrics.CacheHits++
case CacheMetricMiss:
p.metrics.CacheMisses++
case CacheMetricBypass:
p.metrics.CacheBypass++
case CacheMetricStore:
p.metrics.CacheStores++
}
}
// NewTransparentProxy creates a new proxy instance with an internally
// configured logger based on the provided ProxyConfig.
func NewTransparentProxy(config ProxyConfig, validator TokenValidator, store ProjectStore) (*TransparentProxy, error) {
logger, err := logging.NewLogger(config.LogLevel, config.LogFormat, config.LogFile)
if err != nil {
return nil, fmt.Errorf("failed to initialize logger: %w", err)
}
return NewTransparentProxyWithLogger(config, validator, store, logger)
}
// NewTransparentProxyWithObservability creates a new proxy with observability middleware.
func NewTransparentProxyWithObservability(config ProxyConfig, validator TokenValidator, store ProjectStore, obsCfg middleware.ObservabilityConfig) (*TransparentProxy, error) {
logger, err := logging.NewLogger(config.LogLevel, config.LogFormat, config.LogFile)
if err != nil {
return nil, fmt.Errorf("failed to initialize logger: %w", err)
}
return NewTransparentProxyWithLoggerAndObservability(config, validator, store, logger, obsCfg)
}
// NewTransparentProxyWithLogger allows providing a custom logger. If logger is nil
// a new one is created based on the ProxyConfig.
func NewTransparentProxyWithLogger(config ProxyConfig, validator TokenValidator, store ProjectStore, logger *zap.Logger) (*TransparentProxy, error) {
if logger == nil {
var err error
logger, err = logging.NewLogger(config.LogLevel, config.LogFormat, config.LogFile)
if err != nil {
return nil, fmt.Errorf("failed to initialize logger: %w", err)
}
}
// Precompute allowed methods header
allowedMethodsHeader := "GET, POST, PUT, PATCH, DELETE, OPTIONS"
if len(config.AllowedMethods) > 0 {
allowedMethodsHeader = strings.Join(config.AllowedMethods, ", ")
}
proxy := &TransparentProxy{
config: config,
tokenValidator: validator,
projectStore: store,
logger: logger,
metrics: &ProxyMetrics{},
allowedMethodsHeader: allowedMethodsHeader,
}
// Initialize HTTP cache (enabled only when HTTPCacheEnabled is true)
if !config.HTTPCacheEnabled {
logger.Info("HTTP cache disabled")
proxy.cache = nil
} else {
if config.RedisCacheURL != "" {
if opt, err := redis.ParseURL(config.RedisCacheURL); err == nil {
client := redis.NewClient(opt)
proxy.cache = newRedisCache(client, config.RedisCacheKeyPrefix)
logger.Info("HTTP cache enabled", zap.String("backend", "redis"))
} else {
proxy.cache = newInMemoryCache()
logger.Warn("Failed to parse RedisCacheURL; falling back to in-memory cache", zap.Error(err))
}
} else {
proxy.cache = newInMemoryCache()
logger.Info("HTTP cache enabled", zap.String("backend", "in-memory"))
}
}
// Initialize the reverse proxy
reverseProxy := &httputil.ReverseProxy{
Director: proxy.director,
ModifyResponse: proxy.modifyResponse,
ErrorHandler: proxy.errorHandler,
Transport: proxy.createTransport(),
FlushInterval: config.FlushInterval,
}
proxy.proxy = reverseProxy
return proxy, nil
}
// NewTransparentProxyWithLoggerAndObservability creates a proxy with observability middleware using an existing logger.
func NewTransparentProxyWithLoggerAndObservability(config ProxyConfig, validator TokenValidator, store ProjectStore, logger *zap.Logger, obsCfg middleware.ObservabilityConfig) (*TransparentProxy, error) {
p, err := NewTransparentProxyWithLogger(config, validator, store, logger)
if err != nil {
return nil, err
}
p.obsMiddleware = middleware.NewObservabilityMiddleware(obsCfg, logger)
return p, nil
}
// NewTransparentProxyWithAudit creates a proxy with audit logging capabilities.
func NewTransparentProxyWithAudit(config ProxyConfig, validator TokenValidator, store ProjectStore, logger *zap.Logger, auditLogger AuditLogger, obsCfg middleware.ObservabilityConfig) (*TransparentProxy, error) {
p, err := NewTransparentProxyWithLoggerAndObservability(config, validator, store, logger, obsCfg)
if err != nil {
return nil, err
}
p.auditLogger = auditLogger
return p, nil
}
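// Construction sketch (validator and store are caller-supplied TokenValidator
// and ProjectStore implementations; the URL, endpoints, and port are
// illustrative):
//
//	cfg := ProxyConfig{
//		TargetBaseURL:    "https://api.example.com",
//		AllowedEndpoints: []string{"/v1/"},
//		AllowedMethods:   []string{"GET", "POST"},
//	}
//	p, err := NewTransparentProxy(cfg, validator, store)
//	if err != nil {
//		// handle error
//	}
//	_ = http.ListenAndServe(":8080", p.Handler())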
// director is the Director function for the reverse proxy
func (p *TransparentProxy) director(req *http.Request) {
// Store original path in context for logging
*req = *req.WithContext(context.WithValue(req.Context(), ctxKeyOriginalPath, req.URL.Path))
targetURL, err := url.Parse(p.config.TargetBaseURL)
if err != nil {
p.logger.Error("Failed to parse target URL", zap.Error(err))
return
}
// Update request URL
req.URL.Scheme = targetURL.Scheme
req.URL.Host = targetURL.Host
req.Host = targetURL.Host
// Add proxy identification headers
req.Header.Set("X-Proxy", "llm-proxy")
req.Header.Set("X-Proxy-Version", version)
if pid, ok := req.Context().Value(ctxKeyProjectID).(string); ok {
req.Header.Set("X-Proxy-ID", pid)
}
// Preserve or strip certain headers
p.processRequestHeaders(req)
requestID, _ := req.Context().Value(ctxKeyRequestID).(string)
p.logger.Debug("Proxying request",
zap.String("request_id", requestID),
zap.String("method", req.Method),
zap.String("path", req.URL.Path),
zap.String("project_id", req.Header.Get("X-Proxy-ID")))
// Verbose upstream request logging (debug level only). Guarded with an if
// block rather than an early return so the timing header below is always set.
if p.logger.Core().Enabled(zap.DebugLevel) {
headers := make(map[string][]string)
for k, v := range req.Header {
headers[k] = v
}
p.logger.Debug("Upstream request",
zap.String("request_id", requestID),
zap.String("method", req.Method),
zap.String("url", req.URL.String()),
zap.Any("headers", headers),
)
}
// --- PATCH: Add X-UPSTREAM-REQUEST-START header ---
upstreamStart := time.Now().UnixNano()
req.Header.Set("X-UPSTREAM-REQUEST-START", fmt.Sprintf("%d", upstreamStart))
}
// processRequestHeaders handles the manipulation of request headers
func (p *TransparentProxy) processRequestHeaders(req *http.Request) {
// Headers to remove for security/privacy reasons
headersToRemove := []string{
"X-Forwarded-For", // We'll set this ourselves if needed
"X-Real-IP", // Remove client IP for privacy
"CF-Connecting-IP", // Cloudflare headers
"CF-IPCountry", // Cloudflare headers
"X-Client-IP", // Other proxies
"X-Original-Forwarded-For", // Chain of proxies
}
// Remove headers that shouldn't be passed to the upstream
for _, header := range headersToRemove {
req.Header.Del(header)
}
// Set X-Forwarded-For if configured to do so
if p.config.SetXForwardedFor {
// Get the client IP
clientIP := req.RemoteAddr
// Remove port if present
if idx := strings.LastIndex(clientIP, ":"); idx != -1 {
clientIP = clientIP[:idx]
}
req.Header.Set("X-Forwarded-For", clientIP)
}
// If Content-Length is 0 and there's a body, let Go calculate the correct Content-Length
if req.ContentLength == 0 && req.Body != nil {
req.Header.Del("Content-Length")
}
// Ensure proper Accept header for SSE streaming if needed
if isStreamingRequest(req) && req.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
// calculateCacheTTL determines the effective TTL for caching a response.
// It prefers TTL from response headers (when the response is cacheable),
// otherwise falls back to client-forced TTL from the request. It returns the
// chosen TTL and whether it came from the response headers.
func calculateCacheTTL(res *http.Response, req *http.Request, defaultTTL time.Duration) (time.Duration, bool) {
if res == nil || req == nil {
return 0, false
}
respTTL := cacheTTLFromHeaders(res, defaultTTL)
if respTTL > 0 {
if !isResponseCacheable(res) {
return 0, false
}
return respTTL, true
}
forcedTTL := requestForcedCacheTTL(req)
if forcedTTL > 0 {
return forcedTTL, false
}
return 0, false
}
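// Worked example: a response carrying "Cache-Control: max-age=60" yields
// (60*time.Second, true) provided isResponseCacheable accepts it; a response
// with no usable caching headers but a client-forced TTL on the request
// yields (forcedTTL, false); otherwise the result is (0, false).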
func (p *TransparentProxy) modifyResponse(res *http.Response) error {
// For streaming responses, return early without side effects
if isStreaming(res) {
return nil
}
// Set proxy headers
res.Header.Set("X-Proxy", "llm-proxy")
// Process response body to extract metadata for non-streaming responses
if res.StatusCode == http.StatusOK &&
strings.Contains(res.Header.Get("Content-Type"), "application/json") &&
res.Body != nil {
// Extract metadata without consuming the response body
if err := p.extractResponseMetadata(res); err != nil {
p.logger.Warn("Failed to extract response metadata", zap.Error(err))
}
}
// Update metrics
p.metrics.mu.Lock()
p.metrics.RequestCount++
if res.StatusCode >= 400 {
p.metrics.ErrorCount++
}
p.metrics.mu.Unlock()
// --- PATCH: Copy X-UPSTREAM-REQUEST-START from request to response ---
if res.Request != nil {
if v := res.Request.Header.Get("X-UPSTREAM-REQUEST-START"); v != "" {
res.Header.Set("X-UPSTREAM-REQUEST-START", v)
}
}
// Store in cache when enabled and request is cacheable
if p.cache != nil && res.Request != nil {
req := res.Request
if req.Method == http.MethodGet || req.Method == http.MethodHead || req.Method == http.MethodPost {
// Only cache successful responses
if res.StatusCode < 200 || res.StatusCode >= 300 {
res.Header.Set("X-CACHE-DEBUG", "status-not-cacheable")
return nil
}
// Calculate effective TTL
ttl, fromResponse := calculateCacheTTL(res, req, p.config.HTTPCacheDefaultTTL)
if ttl <= 0 {
res.Header.Set("X-CACHE-DEBUG", fmt.Sprintf("ttl-zero-ttl=%v-from-resp=%v", ttl, fromResponse))
return nil
}
// Ensure Cache-Status preserves miss set by handler
if res.Header.Get("Cache-Status") == "" {
res.Header.Set("Cache-Status", "llm-proxy; miss")
}
key := CacheKeyFromRequest(req)
// Compute storage key via helper to respect Vary
storageKey := storageKeyForResponse(req, res.Header.Get("Vary"), key)
if !isStreaming(res) {
bodyBytes, err := io.ReadAll(res.Body)
if err == nil {
_ = res.Body.Close()
res.Body = io.NopCloser(bytes.NewReader(bodyBytes))
if p.config.HTTPCacheMaxObjectBytes == 0 || int64(len(bodyBytes)) <= p.config.HTTPCacheMaxObjectBytes {
headers := cloneHeadersForCache(res.Header)
if !fromResponse {
headers.Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(ttl.Seconds())))
}
// Store the Vary header for per-response cache key generation
varyValue := res.Header.Get("Vary")
cr := cachedResponse{
statusCode: res.StatusCode,
headers: headers,
body: bodyBytes,
expiresAt: time.Now().Add(ttl),
vary: varyValue,
}
p.cache.Set(storageKey, cr)
res.Header.Set("X-PROXY-CACHE", "stored")
res.Header.Set("X-PROXY-CACHE-KEY", storageKey)
p.incrementCacheMetric(CacheMetricStore)
if !fromResponse {
res.Header.Set("Cache-Status", "llm-proxy; stored (forced)")
} else {
res.Header.Set("Cache-Status", "llm-proxy; stored")
}
}
} else {
res.Header.Set("X-CACHE-DEBUG", fmt.Sprintf("read-body-error=%v", err))
}
} else {
// Note: the early return for streaming responses at the top of
// modifyResponse means this branch is effectively defensive; it only
// executes if that guard is removed.
res.Header.Set("X-CACHE-DEBUG", "streaming-response")
maxBytes := p.config.HTTPCacheMaxObjectBytes
if maxBytes <= 0 {
maxBytes = 2 * 1024 * 1024 // default 2MB
}
headers := cloneHeadersForCache(res.Header)
if !fromResponse {
headers.Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(ttl.Seconds())))
}
// Store the Vary header for per-response cache key generation
varyValue := res.Header.Get("Vary")
// Compute storage key via helper
storageKey := storageKeyForResponse(req, varyValue, key)
expiresAt := time.Now().Add(ttl)
orig := res.Body
res.Body = newStreamingCapture(orig, maxBytes, func(buf []byte) {
if len(buf) == 0 {
return
}
if int64(len(buf)) > maxBytes {
return
}
p.cache.Set(storageKey, cachedResponse{
statusCode: res.StatusCode,
headers: headers,
body: append([]byte(nil), buf...),
expiresAt: expiresAt,
vary: varyValue,
})
p.incrementCacheMetric(CacheMetricStore)
})
}
}
}
// Set miss status if no cache status was set
if res.Header.Get("Cache-Status") == "" {
res.Header.Set("Cache-Status", "llm-proxy; miss")
}
return nil
}
// extractResponseMetadata extracts metadata from the response body without consuming it
func (p *TransparentProxy) extractResponseMetadata(res *http.Response) error {
// Check if we need to process the response
if res.Body == nil {
return errors.New("response body is nil")
}
// Only parse as JSON if Content-Type is application/json and not compressed
contentType := res.Header.Get("Content-Type")
contentEncoding := res.Header.Get("Content-Encoding")
if !strings.Contains(contentType, "application/json") || (contentEncoding != "" && contentEncoding != "identity") {
p.logger.Debug("Skipping metadata extraction: not JSON or compressed",
zap.String("content_type", contentType),
zap.String("content_encoding", contentEncoding))
return nil
}
// We need to read the body to extract metadata, but we must also
// preserve it for the client. This is done by creating a new Reader
// that allows us to read the body twice.
bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}
// Replace the body with a new ReadCloser that can be read again
err = res.Body.Close()
if err != nil {
return fmt.Errorf("failed to close response body: %w", err)
}
res.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Parse the body to extract metadata
metadata, err := p.parseOpenAIResponseMetadata(bodyBytes)
if err != nil {
p.logger.Debug("Failed to extract response metadata",
zap.Error(err),
zap.String("content_type", contentType),
zap.String("content_encoding", contentEncoding))
return nil
}
// Add metadata to response headers
for k, v := range metadata {
res.Header.Set(fmt.Sprintf("X-OpenAI-%s", k), v)
}
return nil
}
// parseOpenAIResponseMetadata extracts metadata from OpenAI API responses
func (p *TransparentProxy) parseOpenAIResponseMetadata(bodyBytes []byte) (map[string]string, error) {
metadata := make(map[string]string)
// Try to parse as JSON
var result map[string]interface{}
if err := json.Unmarshal(bodyBytes, &result); err != nil {
return metadata, fmt.Errorf("failed to parse response JSON: %w", err)
}
// Look for usage information
if usage, ok := result["usage"].(map[string]interface{}); ok {
// Extract token counts
if promptTokens, ok := usage["prompt_tokens"].(float64); ok {
metadata["Prompt-Tokens"] = fmt.Sprintf("%.0f", promptTokens)
}
if completionTokens, ok := usage["completion_tokens"].(float64); ok {
metadata["Completion-Tokens"] = fmt.Sprintf("%.0f", completionTokens)
}
if totalTokens, ok := usage["total_tokens"].(float64); ok {
metadata["Total-Tokens"] = fmt.Sprintf("%.0f", totalTokens)
}
}
// Extract model information
if model, ok := result["model"].(string); ok {
metadata["Model"] = model
}
// Extract other potentially useful metadata
if id, ok := result["id"].(string); ok {
metadata["ID"] = id
}
if created, ok := result["created"].(float64); ok {
metadata["Created"] = fmt.Sprintf("%.0f", created)
}
return metadata, nil
}
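// Worked example: the body
//
//	{"id":"cmpl-1","model":"gpt-4","created":1700000000,
//	 "usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}}
//
// produces Model=gpt-4, ID=cmpl-1, Created=1700000000, Prompt-Tokens=5,
// Completion-Tokens=7, and Total-Tokens=12, which extractResponseMetadata
// then exposes as X-OpenAI-* response headers.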
// errorHandler handles errors that occur during proxying
func (p *TransparentProxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) {
// Check if there was a validation error
if validationErr, ok := r.Context().Value(ctxKeyValidationError).(error); ok {
p.handleValidationError(w, r, validationErr)
return
}
// Handle different error types
requestID, _ := r.Context().Value(ctxKeyRequestID).(string)
p.logger.Error("Proxy error",
zap.String("request_id", requestID),
zap.Error(err),
zap.String("method", r.Method),
zap.String("path", r.URL.Path))
statusCode := http.StatusBadGateway
errorResponse := ErrorResponse{
Error: "Proxy error",
}
switch {
case errors.Is(err, context.DeadlineExceeded):
statusCode = http.StatusGatewayTimeout
errorResponse.Error = "Request timeout"
errorResponse.Code = "timeout"
case errors.Is(err, context.Canceled):
statusCode = http.StatusRequestTimeout
errorResponse.Error = "Request canceled"
errorResponse.Code = "canceled"
default:
// Use default values
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
if err := json.NewEncoder(w).Encode(errorResponse); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
}
// handleValidationError handles errors specific to token validation
func (p *TransparentProxy) handleValidationError(w http.ResponseWriter, r *http.Request, err error) {
// Get request ID and token directly from the request
requestID, _ := r.Context().Value(ctxKeyRequestID).(string)
var obfuscatedToken string
authHeader := r.Header.Get("Authorization")
if len(authHeader) > 7 && strings.HasPrefix(authHeader, "Bearer ") {
tok := strings.TrimSpace(authHeader[7:])
obfuscatedToken = token.ObfuscateToken(tok)
}
statusCode := http.StatusUnauthorized
errorResponse := ErrorResponse{
Error: "Authentication error",
}
// Check for specific token errors
switch {
case errors.Is(err, token.ErrTokenNotFound):
errorResponse.Error = "Token not found"
errorResponse.Code = "token_not_found"
case errors.Is(err, token.ErrTokenInactive):
errorResponse.Error = "Token is inactive"
errorResponse.Code = "token_inactive"
case errors.Is(err, token.ErrTokenExpired):
errorResponse.Error = "Token has expired"
errorResponse.Code = "token_expired"
case errors.Is(err, token.ErrTokenRateLimit):
statusCode = http.StatusTooManyRequests
errorResponse.Error = "Rate limit exceeded"
errorResponse.Code = "rate_limit_exceeded"
default:
errorResponse.Error = "Invalid token"
errorResponse.Description = err.Error()
errorResponse.Code = "invalid_token"
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
if err := json.NewEncoder(w).Encode(errorResponse); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
p.logger.Error("Validation error",
zap.String("request_id", requestID),
zap.Error(err),
zap.String("token", obfuscatedToken),
)
}
// createTransport creates an HTTP transport with appropriate settings
func (p *TransparentProxy) createTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: p.config.MaxIdleConns,
MaxIdleConnsPerHost: p.config.MaxIdleConnsPerHost,
IdleConnTimeout: p.config.IdleConnTimeout,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: p.config.ResponseHeaderTimeout,
}
}
// Handler returns the HTTP handler for the proxy
func (p *TransparentProxy) Handler() http.Handler {
baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Short-circuit OPTIONS requests: no auth required, respond with 204 and CORS headers
if r.Method == http.MethodOptions {
// Set CORS headers for preflight requests
if origin := r.Header.Get("Origin"); origin != "" {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
if reqHeaders := r.Header.Get("Access-Control-Request-Headers"); reqHeaders != "" {
w.Header().Set("Access-Control-Allow-Headers", reqHeaders)
} else {
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With")
}
w.Header().Set("Access-Control-Expose-Headers", "X-Request-ID, X-Proxy-ID, X-LLM-Proxy-Remote-Duration, X-LLM-Proxy-Remote-Duration-Ms")
w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours
}
w.WriteHeader(http.StatusNoContent)
return
}
// Generate request ID and add to context
requestID := uuid.New().String()
ctx := context.WithValue(r.Context(), ctxKeyRequestID, requestID)
// Also add to shared logging context so middlewares can read it
ctx = logging.WithRequestID(ctx, requestID)
r = r.WithContext(ctx)
// Set request header so observability/file logger can capture it (response header is still set in ModifyResponse only)
r.Header.Set("X-Request-ID", requestID)
// Record when proxy receives the request
receivedAt := time.Now().UTC()
ctx = context.WithValue(ctx, ctxKeyProxyReceivedAt, receivedAt)
r = r.WithContext(ctx)
// --- Token extraction and validation (moved from director) ---
authHeader := r.Header.Get("Authorization")
tokenStr := extractTokenFromHeader(authHeader)
if tokenStr == "" {
p.handleValidationError(w, r, errors.New("missing or invalid authorization header"))
return
}
projectID, err := p.tokenValidator.ValidateTokenWithTracking(r.Context(), tokenStr)
if err != nil {
p.handleValidationError(w, r, err)
return
}
ctx = context.WithValue(r.Context(), ctxKeyProjectID, projectID)
ctx = context.WithValue(ctx, ctxKeyTokenID, tokenStr)
r = r.WithContext(ctx)
apiKey, err := p.projectStore.GetAPIKeyForProject(r.Context(), projectID)
if err != nil {
p.handleValidationError(w, r, fmt.Errorf("failed to get API key: %w", err))
return
}
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
// Enforce project active status using shared helper (if enabled)
if allowed, status, er := shouldAllowProject(r.Context(), p.config.EnforceProjectActive, p.projectStore, projectID, p.auditLogger, r); !allowed {
writeErrorResponse(w, status, er)
return
}
// Wrap the ResponseWriter to allow us to set headers at first/last byte
rw := &timingResponseWriter{ResponseWriter: w}
// Instrument the reverse proxy. The per-request Director rewrites the URL and
// host and sets proxy identification headers; token validation already
// happened above.
p.proxy.Director = func(req *http.Request) {
// Store original path in context for logging
*req = *req.WithContext(context.WithValue(req.Context(), ctxKeyOriginalPath, req.URL.Path))
targetURL, err := url.Parse(p.config.TargetBaseURL)
if err != nil {
p.logger.Error("Failed to parse target URL", zap.Error(err))
return
}
// Update request URL
req.URL.Scheme = targetURL.Scheme
req.URL.Host = targetURL.Host
req.Host = targetURL.Host
// Add proxy identification headers
req.Header.Set("X-Proxy", "llm-proxy")
req.Header.Set("X-Proxy-Version", version)
if pid, ok := req.Context().Value(ctxKeyProjectID).(string); ok {
req.Header.Set("X-Proxy-ID", pid)
}
// Preserve or strip certain headers
p.processRequestHeaders(req)
requestID, _ := req.Context().Value(ctxKeyRequestID).(string)
p.logger.Debug("Proxying request",
zap.String("request_id", requestID),
zap.String("method", req.Method),
zap.String("path", req.URL.Path),
zap.String("project_id", req.Header.Get("X-Proxy-ID")))
// --- PATCH: Add X-UPSTREAM-REQUEST-START header ---
upstreamStart := time.Now().UnixNano()
req.Header.Set("X-UPSTREAM-REQUEST-START", fmt.Sprintf("%d", upstreamStart))
}
p.proxy.ModifyResponse = func(res *http.Response) error {
firstRespAt := time.Now().UTC()
ctx := context.WithValue(res.Request.Context(), ctxKeyProxyFirstRespAt, firstRespAt)
res.Request = res.Request.WithContext(ctx)
ctx = context.WithValue(res.Request.Context(), ctxKeyProxyFinalRespAt, firstRespAt)
res.Request = res.Request.WithContext(ctx)
setTimingHeaders(res, res.Request.Context())
requestID, _ := res.Request.Context().Value(ctxKeyRequestID).(string)
if requestID != "" {
res.Header.Set("X-Request-ID", requestID)
}
logProxyTimings(p.logger, res.Request.Context())
// --- PATCH: Add X-UPSTREAM-REQUEST-STOP header ---
upstreamStop := time.Now().UnixNano()
res.Header.Set("X-UPSTREAM-REQUEST-STOP", fmt.Sprintf("%d", upstreamStop))
return p.modifyResponse(res)
}
p.proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
logProxyTimings(p.logger, r.Context())
p.errorHandler(w, r, err)
}
// Compute X-Body-Hash once for methods with bodies to support POST/PUT/PATCH caching
if r.Body != nil && (r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch) {
bodyBytes, _ := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Hash the body so identical payloads map to the same cache key
sum := sha256.Sum256(bodyBytes)
r.Header.Set("X-Body-Hash", hex.EncodeToString(sum[:]))
}
// Simple cache lookup with conditional handling (ETag/Last-Modified)
if p.cache != nil && (r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodPost) {
// Allow GET/HEAD lookups by default when cache is enabled, since reuse will still be gated by canServeCachedForRequest.
// Require explicit client opt-in for POST lookups.
optIn := hasClientCacheOptIn(r)
allowedLookup := (r.Method == http.MethodGet || r.Method == http.MethodHead) || (r.Method == http.MethodPost && optIn)
if !allowedLookup {
// Cache is enabled but this request type/method is not cacheable - count as miss
p.recordCacheMiss()
p.proxy.ServeHTTP(rw, r)
return
}
key := CacheKeyFromRequest(r)
if cr, ok := p.cache.Get(key); ok {
// Validate Vary compatibility using helper
if !isVaryCompatible(r, cr, key) {
p.recordCacheMiss()
// Note: don't set miss status here; let modifyResponse handle cache status
p.proxy.ServeHTTP(rw, r)
return
}
if !canServeCachedForRequest(r, cr.headers) {
// Authorization present but cached response not explicitly shared-cacheable
w.Header().Set("Cache-Status", "llm-proxy; bypass")
w.Header().Set("X-PROXY-CACHE", "bypass")
w.Header().Set("X-PROXY-CACHE-KEY", key)
p.incrementCacheMetric(CacheMetricBypass)
p.proxy.ServeHTTP(rw, r)
return
}
// Origin revalidation path: if client requests revalidation (no-cache/max-age=0),
// send conditional request upstream using cached validators (ETag/Last-Modified).
if wantsRevalidation(r) {
condReq := r.Clone(r.Context())
if etag := cr.headers.Get("ETag"); etag != "" {
condReq.Header.Set("If-None-Match", etag)
}
if lm := cr.headers.Get("Last-Modified"); lm != "" {
condReq.Header.Set("If-Modified-Since", lm)
}
// Forward conditionally to upstream; let modifyResponse handle store/refresh
// Don't increment miss here since this is a conditional revalidation
p.proxy.ServeHTTP(rw, condReq)
return
}
// If the client provided conditionals, respond 304 when validators match
if r.Method == http.MethodGet || r.Method == http.MethodHead {
if conditionalRequestMatches(r, cr.headers) {
for hk, hv := range cr.headers {
for _, v := range hv {
w.Header().Add(hk, v)
}
}
w.Header().Set("Cache-Status", "llm-proxy; conditional-hit")
w.Header().Set("X-PROXY-CACHE", "conditional-hit")
w.Header().Set("X-PROXY-CACHE-KEY", key)
p.recordCacheHit(r) // Conditional hit counts as cache hit
w.WriteHeader(http.StatusNotModified)
return
}
}
for hk, hv := range cr.headers {
for _, v := range hv {
w.Header().Add(hk, v)
}
}
w.Header().Set("Cache-Status", "llm-proxy; hit")
w.Header().Set("X-PROXY-CACHE", "hit")
w.Header().Set("X-PROXY-CACHE-KEY", key)
p.recordCacheHit(r)
w.WriteHeader(cr.statusCode)
if r.Method != http.MethodHead {
_, _ = w.Write(cr.body)
}
return
}
// Cache miss - no entry found
p.recordCacheMiss()
// Note: don't set miss status here; let modifyResponse handle cache status
// w.Header().Set("Cache-Status", "llm-proxy; miss")
// Do not set X-PROXY-CACHE(-KEY) on miss; only set definitive headers on hit/bypass/conditional-hit or store path
} else if p.cache != nil {
// Cache is enabled but method is not cacheable (e.g., DELETE, OPTIONS, etc.) - count as miss
p.recordCacheMiss()
}
p.proxy.ServeHTTP(rw, r)
})
var handler http.Handler = baseHandler
handler = p.ValidateRequestMiddleware()(handler)
if p.obsMiddleware != nil {
handler = p.obsMiddleware.Middleware()(handler)
}
handler = CircuitBreakerMiddleware(5, 30*time.Second, func(status int) bool {
return status == http.StatusBadGateway || status == http.StatusServiceUnavailable || status == http.StatusGatewayTimeout
})(handler)
return handler
}
type timingResponseWriter struct {
http.ResponseWriter
firstByteOnce sync.Once
firstByteAt time.Time
finalByteAt time.Time
}
func (w *timingResponseWriter) Write(b []byte) (int, error) {
now := time.Now().UTC()
w.firstByteOnce.Do(func() {
w.firstByteAt = now
w.Header().Set("X-Proxy-First-Response-At", w.firstByteAt.Format(time.RFC3339Nano))
})
w.finalByteAt = now
return w.ResponseWriter.Write(b)
}
func (w *timingResponseWriter) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// recordCacheMiss centralizes cache miss accounting to reduce duplication and
// ensure consistent metric semantics across all miss paths.
func (p *TransparentProxy) recordCacheMiss() {
p.incrementCacheMetric(CacheMetricMiss)
}
// recordCacheHit records a cache hit for metrics and per-token tracking.
func (p *TransparentProxy) recordCacheHit(r *http.Request) {
p.incrementCacheMetric(CacheMetricHit)
// Record per-token cache hit if aggregator is configured
if p.cacheStatsAggregator != nil {
if tokenID, ok := r.Context().Value(ctxKeyTokenID).(string); ok && tokenID != "" {
p.cacheStatsAggregator.RecordCacheHit(tokenID)
}
}
}
func setTimingHeaders(res *http.Response, ctx context.Context) {
if v := ctx.Value(ctxKeyProxyReceivedAt); v != nil {
if t, ok := v.(time.Time); ok {
res.Header.Set("X-Proxy-Received-At", t.Format(time.RFC3339Nano))
}
}
if v := ctx.Value(ctxKeyProxySentBackendAt); v != nil {
if t, ok := v.(time.Time); ok {
res.Header.Set("X-Proxy-Sent-Backend-At", t.Format(time.RFC3339Nano))
}
}
if v := ctx.Value(ctxKeyProxyFirstRespAt); v != nil {
if t, ok := v.(time.Time); ok {
res.Header.Set("X-Proxy-First-Response-At", t.Format(time.RFC3339Nano))
}
}
if v := ctx.Value(ctxKeyProxyFinalRespAt); v != nil {
if t, ok := v.(time.Time); ok {
res.Header.Set("X-Proxy-Final-Response-At", t.Format(time.RFC3339Nano))
}
}
}
func logProxyTimings(logger *zap.Logger, ctx context.Context) {
received, _ := ctx.Value(ctxKeyProxyReceivedAt).(time.Time)
sent, _ := ctx.Value(ctxKeyProxySentBackendAt).(time.Time)
first, _ := ctx.Value(ctxKeyProxyFirstRespAt).(time.Time)
final, _ := ctx.Value(ctxKeyProxyFinalRespAt).(time.Time)
requestID, _ := ctx.Value(ctxKeyRequestID).(string)
if !received.IsZero() && !sent.IsZero() {
logger.Debug("Proxy overhead (pre-backend)", zap.Duration("duration", sent.Sub(received)), zap.String("request_id", requestID))
}
if !sent.IsZero() && !first.IsZero() {
logger.Debug("Backend latency (first byte)", zap.Duration("duration", first.Sub(sent)), zap.String("request_id", requestID))
}
if !first.IsZero() && !final.IsZero() {
logger.Debug("Streaming duration", zap.Duration("duration", final.Sub(first)), zap.String("request_id", requestID))
}
if !received.IsZero() && !final.IsZero() {
logger.Debug("Total proxy duration", zap.Duration("duration", final.Sub(received)), zap.String("request_id", requestID))
}
}
// LoggingMiddleware logs request details
func (p *TransparentProxy) LoggingMiddleware() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
requestID, _ := r.Context().Value(ctxKeyRequestID).(string)
p.logger.Info("Request started",
zap.String("request_id", requestID),
zap.String("method", r.Method),
zap.String("path", r.URL.Path),
zap.String("remote_addr", r.RemoteAddr))
// Create a response recorder to capture response details
rec := &responseRecorder{
ResponseWriter: w,
statusCode: http.StatusOK,
}
// Process request
next.ServeHTTP(rec, r)
// Log request completion
duration := time.Since(start)
p.logger.Info("Request completed",
zap.String("request_id", requestID),
zap.String("method", r.Method),
zap.String("path", r.URL.Path),
zap.Int("status", rec.statusCode),
zap.Duration("duration", duration))
})
}
}
// ValidateRequestMiddleware validates the incoming request against allowed endpoints and methods
func (p *TransparentProxy) ValidateRequestMiddleware() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Ensure request_id is set in context as the very first step
requestID, ok := r.Context().Value(ctxKeyRequestID).(string)
if !ok || requestID == "" {
requestID = uuid.New().String()
r = r.WithContext(context.WithValue(r.Context(), ctxKeyRequestID, requestID))
}
// --- Validation Scope: Only token, path, and method are validated here ---
// Do not add API-specific validation or transformation logic here.
// Check if method is allowed
if !p.isMethodAllowed(r.Method) {
p.logger.Warn("Method not allowed",
zap.String("method", r.Method),
zap.String("path", r.URL.Path))
if requestID != "" {
w.Header().Set("X-Request-ID", requestID)
}
w.Header().Set("Content-Type", "application/json")
// Headers must be set before WriteHeader, or they are silently dropped.
w.WriteHeader(http.StatusMethodNotAllowed)
if err := json.NewEncoder(w).Encode(ErrorResponse{
Error: "Method not allowed",
Code: "method_not_allowed",
}); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
return
}
// Check if endpoint is allowed
if !p.isEndpointAllowed(r.URL.Path) {
p.logger.Warn("Endpoint not allowed",
zap.String("method", r.Method),
zap.String("path", r.URL.Path))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound)
if err := json.NewEncoder(w).Encode(ErrorResponse{
Error: "Endpoint not found",
Code: "endpoint_not_found",
}); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
return
}
// --- End of validation scope ---
// --- Begin param whitelist validation ---
if r.Method == http.MethodPost && len(p.config.ParamWhitelist) > 0 && r.Header.Get("Content-Type") == "application/json" {
// Read and buffer the body for validation and later proxying
var bodyBytes []byte
if r.Body != nil {
bodyBytes, _ = io.ReadAll(r.Body)
}
if len(bodyBytes) > 0 {
var bodyMap map[string]interface{}
if err := json.Unmarshal(bodyBytes, &bodyMap); err == nil {
for param, allowed := range p.config.ParamWhitelist {
if val, ok := bodyMap[param]; ok {
valStr := ""
switch v := val.(type) {
case string:
valStr = v
case float64:
valStr = fmt.Sprintf("%v", v)
default:
valStr = fmt.Sprintf("%v", v)
}
found := false
// Support glob expressions in allowed values
for _, allowedVal := range allowed {
if ok, _ := path.Match(allowedVal, valStr); ok {
found = true
break
}
}
if !found {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if err := json.NewEncoder(w).Encode(ErrorResponse{
Error: fmt.Sprintf("Parameter '%s' value '%s' is not allowed. Allowed patterns: %v", param, valStr, allowed),
Code: "param_not_allowed",
}); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
return
}
}
}
}
// Restore the body for downstream handlers
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
}
// --- End param whitelist validation ---
// --- Begin CORS origin validation ---
origin := r.Header.Get("Origin")
originRequired := false
for _, h := range p.config.RequiredHeaders {
if strings.EqualFold(h, "origin") {
originRequired = true
break
}
}
if originRequired {
if origin == "" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if err := json.NewEncoder(w).Encode(ErrorResponse{
Error: "Origin header required",
Code: "origin_required",
}); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
return
}
if len(p.config.AllowedOrigins) > 0 {
allowed := false
for _, o := range p.config.AllowedOrigins {
if o == origin {
allowed = true
break
}
}
if !allowed {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
if err := json.NewEncoder(w).Encode(ErrorResponse{
Error: "Origin not allowed",
Code: "origin_not_allowed",
}); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
return
}
}
} else if origin != "" && len(p.config.AllowedOrigins) > 0 {
allowed := false
for _, o := range p.config.AllowedOrigins {
if o == origin {
allowed = true
break
}
}
if !allowed {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
if err := json.NewEncoder(w).Encode(ErrorResponse{
Error: "Origin not allowed",
Code: "origin_not_allowed",
}); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
return
}
}
// --- End CORS origin validation ---
// Continue to next middleware
next.ServeHTTP(w, r)
})
}
}
// TimeoutMiddleware adds a timeout to requests
func (p *TransparentProxy) TimeoutMiddleware(timeout time.Duration) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// MetricsMiddleware collects metrics about requests
func (p *TransparentProxy) MetricsMiddleware() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Store start time in context
ctx := context.WithValue(r.Context(), ctxKeyRequestStart, start)
// Create response recorder to capture status code
rec := &responseRecorder{
ResponseWriter: w,
statusCode: http.StatusOK,
}
// Process request
next.ServeHTTP(rec, r.WithContext(ctx))
// Record metrics
duration := time.Since(start)
p.metrics.mu.Lock()
p.metrics.RequestCount++
// Increment error count for status codes >= 400
if rec.statusCode >= 400 {
p.metrics.ErrorCount++
}
p.metrics.TotalResponseTime += duration
p.metrics.mu.Unlock()
})
}
}
// Shutdown gracefully shuts down the proxy
func (p *TransparentProxy) Shutdown(ctx context.Context) error {
p.mu.Lock()
p.shuttingDown = true
p.mu.Unlock()
p.logger.Info("Shutting down proxy")
// If we have an HTTP server, shut it down
if p.httpServer != nil {
return p.httpServer.Shutdown(ctx)
}
return nil
}
// isMethodAllowed checks if a method is in the allowed list
func (p *TransparentProxy) isMethodAllowed(method string) bool {
// If no allowed methods are specified, allow all methods
if len(p.config.AllowedMethods) == 0 {
return true
}
for _, allowed := range p.config.AllowedMethods {
if strings.EqualFold(method, allowed) {
return true
}
}
return false
}
// isEndpointAllowed checks if an endpoint is in the allowed list
func (p *TransparentProxy) isEndpointAllowed(path string) bool {
// If no allowed endpoints are specified, allow all endpoints
if len(p.config.AllowedEndpoints) == 0 {
return true
}
// Check if path matches any allowed endpoint
for _, endpoint := range p.config.AllowedEndpoints {
if strings.HasPrefix(path, endpoint) {
return true
}
}
return false
}
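// Example: with AllowedEndpoints = []string{"/v1/chat"}, both "/v1/chat" and
// "/v1/chat/completions" pass the prefix check, while "/v1/models" does not.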
// responseRecorder is a wrapper for http.ResponseWriter that records the status code
type responseRecorder struct {
http.ResponseWriter
statusCode int
}
// WriteHeader records the status code and calls the wrapped ResponseWriter's WriteHeader method
func (r *responseRecorder) WriteHeader(statusCode int) {
r.statusCode = statusCode
r.ResponseWriter.WriteHeader(statusCode)
}
// Add Flush forwarding for streaming support
func (r *responseRecorder) Flush() {
if f, ok := r.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// extractTokenFromHeader extracts a token from an authorization header
func extractTokenFromHeader(authHeader string) string {
if authHeader == "" {
return ""
}
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
return ""
}
return strings.TrimSpace(parts[1])
}
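// Examples: "Bearer abc123" -> "abc123" (the scheme is case-insensitive);
// "Basic abc123" and an empty header both return "".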
// isStreaming checks if a response is a streaming response
func isStreaming(res *http.Response) bool {
// Check Content-Type for SSE
if strings.Contains(res.Header.Get("Content-Type"), "text/event-stream") {
return true
}
// Check for chunked transfer encoding
return strings.Contains(
strings.ToLower(res.Header.Get("Transfer-Encoding")),
"chunked",
)
}
// isStreamingRequest checks if a request is intended for streaming
func isStreamingRequest(req *http.Request) bool {
// Check for SSE Accept header
if strings.Contains(req.Header.Get("Accept"), "text/event-stream") {
return true
}
// Check query parameters for stream=true (common in OpenAI APIs)
if req.URL.Query().Get("stream") == "true" {
return true
}
// Check the request body for streaming flag
// This is a heuristic and may need refinement for specific APIs
// For OpenAI, the common pattern is POST with JSON containing "stream": true
// But checking this would require reading the body, which we want to avoid
// We'll just rely on the Accept header and query params for now
return false
}
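// Example requests the heuristics above classify as streaming; the URLs and
// header values are illustrative only.
//
//	req, _ := http.NewRequest(http.MethodPost, "https://api.example.com/v1/chat/completions?stream=true", nil)
//	isStreamingRequest(req) // true via the stream=true query parameter
//
//	req2, _ := http.NewRequest(http.MethodGet, "https://api.example.com/v1/events", nil)
//	req2.Header.Set("Accept", "text/event-stream")
//	isStreamingRequest(req2) // true via the SSE Accept header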
package proxy
import (
"bytes"
"io"
"sync/atomic"
)
// streamingCaptureReadCloser wraps an io.ReadCloser and captures the bytes
// read into an internal buffer. When EOF is reached or Close is called, it
// invokes the provided finalize callback exactly once with the captured
// bytes (if any).
type streamingCaptureReadCloser struct {
rc io.ReadCloser
buf *bytes.Buffer
maxBytes int64
written int64
finalized int32
onDone func([]byte)
}
func newStreamingCapture(rc io.ReadCloser, maxBytes int64, onDone func([]byte)) *streamingCaptureReadCloser {
return &streamingCaptureReadCloser{
rc: rc,
buf: &bytes.Buffer{},
maxBytes: maxBytes,
onDone: onDone,
}
}
func (s *streamingCaptureReadCloser) Read(p []byte) (int, error) {
n, err := s.rc.Read(p)
if n > 0 && (s.maxBytes == 0 || s.written < s.maxBytes) {
limit := n
remaining := s.maxBytes - s.written
if s.maxBytes > 0 && int64(limit) > remaining {
limit = int(remaining)
}
_, _ = s.buf.Write(p[:limit])
s.written += int64(limit)
}
if err == io.EOF {
s.finalize()
}
return n, err
}
func (s *streamingCaptureReadCloser) Close() error {
// Ensure finalize happens once even if Close is called before EOF
s.finalize()
return s.rc.Close()
}
func (s *streamingCaptureReadCloser) finalize() {
if atomic.CompareAndSwapInt32(&s.finalized, 0, 1) {
if s.onDone != nil {
s.onDone(s.buf.Bytes())
}
}
}
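// Usage sketch: wrapping a response body so the first maxBytes of a streamed
// payload can be inspected after the stream finishes. The onDone callback and
// the 64 KiB cap are illustrative assumptions.
//
//	res.Body = newStreamingCapture(res.Body, 64*1024, func(captured []byte) {
//		// Invoked exactly once, at EOF or Close, with up to maxBytes of data.
//		log.Printf("captured %d bytes of streamed response", len(captured))
//	})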
// Package server implements the HTTP server for the LLM Proxy.
// It handles request routing, lifecycle management, and provides
// health check endpoints and core API functionality.
package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/sofatutor/llm-proxy/internal/audit"
"github.com/sofatutor/llm-proxy/internal/config"
"github.com/sofatutor/llm-proxy/internal/database"
"github.com/sofatutor/llm-proxy/internal/eventbus"
"github.com/sofatutor/llm-proxy/internal/logging"
"github.com/sofatutor/llm-proxy/internal/middleware"
"github.com/sofatutor/llm-proxy/internal/proxy"
"github.com/sofatutor/llm-proxy/internal/token"
"go.uber.org/zap"
)
// Server represents the HTTP server for the LLM Proxy.
// It encapsulates the underlying http.Server along with application configuration
// and handles request routing and server lifecycle management.
type Server struct {
server *http.Server
config *config.Config
tokenStore token.TokenStore
projectStore proxy.ProjectStore
logger *zap.Logger
proxy *proxy.TransparentProxy
metrics Metrics
eventBus eventbus.EventBus
auditLogger *audit.Logger
db *database.DB
cacheStatsAgg *proxy.CacheStatsAggregator
}
// HealthResponse is the response body for the health check endpoint.
// It provides basic information about the server status and version.
type HealthResponse struct {
Status string `json:"status"` // Service status, "ok" for a healthy system
Timestamp time.Time `json:"timestamp"` // Current server time
Version string `json:"version"` // Application version number
}
// Metrics holds runtime metrics for the server.
type Metrics struct {
StartTime time.Time
RequestCount int64
ErrorCount int64
}
// Version is the application version, following semantic versioning.
const Version = "0.1.0"
// maxDurationMinutes is the maximum allowed duration for a token (365 days)
const maxDurationMinutes = 525600
// New creates a new HTTP server with the provided configuration and store implementations.
// It initializes the server with appropriate timeouts and registers all necessary route handlers.
// The server is not started until the Start method is called.
func New(cfg *config.Config, tokenStore token.TokenStore, projectStore proxy.ProjectStore) (*Server, error) {
return NewWithDatabase(cfg, tokenStore, projectStore, nil)
}
// NewWithDatabase creates a new HTTP server with database support for audit logging.
// This allows the server to store audit events in both file and database backends.
func NewWithDatabase(cfg *config.Config, tokenStore token.TokenStore, projectStore proxy.ProjectStore, db *database.DB) (*Server, error) {
mux := http.NewServeMux()
logger, err := logging.NewLogger(cfg.LogLevel, cfg.LogFormat, cfg.LogFile)
if err != nil {
return nil, fmt.Errorf("failed to initialize logger: %w", err)
}
// Initialize audit logger with optional database backend
var auditLogger *audit.Logger
if cfg.AuditEnabled && cfg.AuditLogFile != "" {
auditConfig := audit.LoggerConfig{
FilePath: cfg.AuditLogFile,
CreateDir: cfg.AuditCreateDir,
DatabaseStore: db, // Database store for audit events
EnableDatabase: cfg.AuditStoreInDB && db != nil,
}
auditLogger, err = audit.NewLogger(auditConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize audit logger: %w", err)
}
if cfg.AuditStoreInDB && db != nil {
logger.Info("Audit logging enabled with database storage", zap.String("log_file", cfg.AuditLogFile))
} else {
logger.Info("Audit logging enabled", zap.String("log_file", cfg.AuditLogFile))
}
} else {
auditLogger = audit.NewNullLogger()
logger.Info("Audit logging disabled")
}
metrics := Metrics{StartTime: time.Now()}
var bus eventbus.EventBus
switch cfg.EventBusBackend {
case "redis":
client := redis.NewClient(&redis.Options{
Addr: cfg.RedisAddr,
DB: cfg.RedisDB,
})
// Redis diagnostics
logger.Info("Connecting to Redis", zap.String("addr", cfg.RedisAddr), zap.Int("db", cfg.RedisDB))
pong, err := client.Ping(context.Background()).Result()
if err != nil {
logger.Fatal("Failed to ping Redis",
zap.String("addr", cfg.RedisAddr),
zap.Int("db", cfg.RedisDB),
zap.Error(err))
}
logger.Info("Successfully pinged Redis",
zap.String("addr", cfg.RedisAddr),
zap.Int("db", cfg.RedisDB),
zap.String("response", pong))
err = client.Set(context.Background(), "llm-proxy-debug", "hello", 0).Err()
if err != nil {
logger.Fatal("Failed to set test key in Redis", zap.Error(err))
}
logger.Debug("Successfully set test key in Redis", zap.String("key", "llm-proxy-debug"))
adapter := &eventbus.RedisGoClientAdapter{Client: client}
bus = eventbus.NewRedisEventBusPublisher(adapter, "llm-proxy-events")
logger.Info("Using Redis event bus", zap.String("addr", cfg.RedisAddr), zap.Int("db", cfg.RedisDB))
case "redis-streams":
client := redis.NewClient(&redis.Options{
Addr: cfg.RedisAddr,
DB: cfg.RedisDB,
})
// Redis diagnostics
logger.Info("Connecting to Redis for Streams", zap.String("addr", cfg.RedisAddr), zap.Int("db", cfg.RedisDB))
pong, err := client.Ping(context.Background()).Result()
if err != nil {
logger.Fatal("Failed to ping Redis",
zap.String("addr", cfg.RedisAddr),
zap.Int("db", cfg.RedisDB),
zap.Error(err))
}
logger.Info("Successfully pinged Redis",
zap.String("addr", cfg.RedisAddr),
zap.Int("db", cfg.RedisDB),
zap.String("response", pong))
// Generate consumer name if not provided
consumerName := cfg.RedisConsumerName
if consumerName == "" {
consumerName = fmt.Sprintf("proxy-%s", uuid.New().String()[:8])
}
streamsConfig := eventbus.RedisStreamsConfig{
StreamKey: cfg.RedisStreamKey,
ConsumerGroup: cfg.RedisConsumerGroup,
ConsumerName: consumerName,
MaxLen: cfg.RedisStreamMaxLen,
BlockTimeout: cfg.RedisStreamBlockTime,
ClaimMinIdleTime: cfg.RedisStreamClaimTime,
BatchSize: cfg.RedisStreamBatchSize,
}
adapter := &eventbus.RedisStreamsClientAdapter{Client: client}
bus = eventbus.NewRedisStreamsEventBus(adapter, streamsConfig)
logger.Info("Using Redis Streams event bus",
zap.String("addr", cfg.RedisAddr),
zap.Int("db", cfg.RedisDB),
zap.String("stream", cfg.RedisStreamKey),
zap.String("consumer_group", cfg.RedisConsumerGroup),
zap.String("consumer_name", consumerName))
case "in-memory":
logger.Info("Using in-memory event bus", zap.String("mode", "single-process"))
bus = eventbus.NewInMemoryEventBus(cfg.ObservabilityBufferSize)
default:
return nil, fmt.Errorf("unknown event bus backend: %s", cfg.EventBusBackend)
}
s := &Server{
config: cfg,
tokenStore: tokenStore,
projectStore: projectStore,
logger: logger,
metrics: metrics,
eventBus: bus,
auditLogger: auditLogger,
db: db,
server: &http.Server{
Addr: cfg.ListenAddr,
Handler: mux,
ReadTimeout: cfg.RequestTimeout,
WriteTimeout: cfg.RequestTimeout,
IdleTimeout: cfg.RequestTimeout * 2,
},
}
// Register routes
mux.HandleFunc("/health", s.logRequestMiddleware(s.handleHealth))
mux.HandleFunc("/ready", s.logRequestMiddleware(s.handleReady))
mux.HandleFunc("/live", s.logRequestMiddleware(s.handleLive))
mux.HandleFunc("/manage/projects", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleProjects)))
mux.HandleFunc("/manage/projects/", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleProjectByID)))
mux.HandleFunc("/manage/tokens", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleTokens)))
mux.HandleFunc("/manage/tokens/", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleTokenByID)))
mux.HandleFunc("/manage/audit", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleAuditEvents)))
mux.HandleFunc("/manage/audit/", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleAuditEventByID)))
mux.HandleFunc("/manage/cache/purge", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleCachePurge)))
// Add catch-all handler for unmatched routes to ensure logging
mux.HandleFunc("/", s.logRequestMiddleware(s.handleNotFound))
if cfg.EnableMetrics {
path := cfg.MetricsPath
if path == "" {
path = "/metrics"
}
mux.HandleFunc(path, s.logRequestMiddleware(s.handleMetrics))
}
return s, nil
}
// Start initializes all required components and starts the HTTP server.
// This method blocks until the server is shut down or an error occurs.
//
// It returns an error if the server fails to start or encounters an
// unrecoverable error during operation.
func (s *Server) Start() error {
// Initialize required components
if err := s.initializeComponents(); err != nil {
return fmt.Errorf("failed to initialize components: %w", err)
}
s.logger.Info("Server starting", zap.String("listen_addr", s.config.ListenAddr))
return s.server.ListenAndServe()
}
// initializeComponents sets up all the required components for the server
func (s *Server) initializeComponents() error {
// Initialize API routes from configuration
if err := s.initializeAPIRoutes(); err != nil {
return fmt.Errorf("failed to initialize API routes: %w", err)
}
// Pending: database, logging, admin, and metrics initialization.
// See server_test.go for test stubs covering these responsibilities.
return nil
}
// initializeAPIRoutes sets up the API proxy routes based on configuration
func (s *Server) initializeAPIRoutes() error {
// Load API providers configuration
apiConfig, err := proxy.LoadAPIConfigFromFile(s.config.APIConfigPath)
if err != nil {
// If the config file doesn't exist or has errors, fall back to a default OpenAI configuration
s.logger.Warn("Failed to load API config, using default OpenAI configuration",
zap.String("config_path", s.config.APIConfigPath),
zap.Error(err))
// Create a default API configuration
apiConfig = &proxy.APIConfig{
DefaultAPI: "openai",
APIs: map[string]*proxy.APIProviderConfig{
"openai": {
BaseURL: s.config.OpenAIAPIURL,
AllowedEndpoints: []string{
"/v1/chat/completions",
"/v1/completions",
"/v1/embeddings",
"/v1/models",
"/v1/edits",
"/v1/fine-tunes",
"/v1/files",
"/v1/images/generations",
"/v1/audio/transcriptions",
"/v1/moderations",
},
AllowedMethods: []string{"GET", "POST", "DELETE"},
Timeouts: proxy.TimeoutConfig{
Request: s.config.RequestTimeout,
ResponseHeader: 30 * time.Second,
IdleConnection: 90 * time.Second,
FlushInterval: 100 * time.Millisecond,
},
Connection: proxy.ConnectionConfig{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 20,
},
},
},
}
}
// Get proxy configuration for the default API provider
proxyConfig, err := apiConfig.GetProxyConfigForAPI(s.config.DefaultAPIProvider)
if err != nil {
// If specified provider doesn't exist, use the default one
s.logger.Warn("Specified API provider not found, using default", zap.Error(err))
proxyConfig, err = apiConfig.GetProxyConfigForAPI(apiConfig.DefaultAPI)
if err != nil {
return fmt.Errorf("failed to get proxy configuration: %w", err)
}
}
// Apply HTTP cache env overrides (simple toggle + backend selection)
if v := os.Getenv("HTTP_CACHE_ENABLED"); v != "" {
// Treat common truthy spellings as enabled; any other explicit value disables the cache
proxyConfig.HTTPCacheEnabled = strings.EqualFold(v, "true") || strings.EqualFold(v, "1") || strings.EqualFold(v, "yes")
} else {
// Default: enabled
proxyConfig.HTTPCacheEnabled = true
}
backend := strings.ToLower(os.Getenv("HTTP_CACHE_BACKEND"))
if backend == "redis" {
// Use REDIS_CACHE_URL if explicitly set, otherwise construct it from REDIS_ADDR
redisURL := os.Getenv("REDIS_CACHE_URL")
if redisURL == "" {
// Construct URL from unified REDIS_ADDR config (same as event bus)
addr := os.Getenv("REDIS_ADDR")
if addr == "" {
addr = "localhost:6379"
}
db := os.Getenv("REDIS_DB")
if db == "" {
db = "0"
}
redisURL = fmt.Sprintf("redis://%s/%s", addr, db)
}
proxyConfig.RedisCacheURL = redisURL
if kp := os.Getenv("REDIS_CACHE_KEY_PREFIX"); kp != "" {
proxyConfig.RedisCacheKeyPrefix = kp
}
}
// Use the injected tokenStore and projectStore
// (No more creation of mock stores or test data here)
tokenValidator := token.NewValidator(s.tokenStore)
cachedValidator := token.NewCachedValidator(tokenValidator)
obsCfg := middleware.ObservabilityConfig{Enabled: s.config.ObservabilityEnabled, EventBus: s.eventBus}
proxyHandler, err := proxy.NewTransparentProxyWithAudit(*proxyConfig, cachedValidator, s.projectStore, s.logger, s.auditLogger, obsCfg)
if err != nil {
return fmt.Errorf("failed to initialize proxy: %w", err)
}
s.proxy = proxyHandler
// Initialize cache stats aggregator for per-token cache hit tracking.
// NOTE: Cache stats tracking is only enabled when HTTP caching is enabled (HTTPCacheEnabled=true).
// When caching is disabled, no cache hits occur, so tracking is not needed.
// The Admin UI will show CacheHitCount=0 for all tokens when caching is disabled.
if s.db != nil && proxyConfig.HTTPCacheEnabled {
aggConfig := proxy.CacheStatsAggregatorConfig{
BufferSize: s.config.CacheStatsBufferSize,
FlushInterval: 5 * time.Second,
BatchSize: 100,
}
s.cacheStatsAgg = proxy.NewCacheStatsAggregator(aggConfig, s.db, s.logger)
s.cacheStatsAgg.Start()
proxyHandler.SetCacheStatsAggregator(s.cacheStatsAgg)
s.logger.Info("Cache stats aggregator started", zap.Int("buffer_size", aggConfig.BufferSize))
}
// Register proxy routes
s.server.Handler.(*http.ServeMux).Handle("/v1/", proxyHandler.Handler())
s.logger.Info("Initialized proxy",
zap.String("target_base_url", proxyConfig.TargetBaseURL),
zap.Int("allowed_endpoints", len(proxyConfig.AllowedEndpoints)))
return nil
}
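// Environment sketch for the cache overrides applied above; the values are
// examples, not defaults.
//
//	HTTP_CACHE_ENABLED=true      // anything other than true/1/yes disables the cache
//	HTTP_CACHE_BACKEND=redis     // selects the Redis-backed cache
//	REDIS_CACHE_URL=redis://cache.internal:6379/0 // optional; otherwise built from REDIS_ADDR and REDIS_DB
//	REDIS_CACHE_KEY_PREFIX=llmproxy:              // optional key namespace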
// Shutdown gracefully shuts down the server without interrupting
// active connections. It waits for all connections to complete
// or for the provided context to be canceled, whichever comes first.
//
// The context should typically include a timeout to prevent
// the shutdown from blocking indefinitely.
func (s *Server) Shutdown(ctx context.Context) error {
// Stop cache stats aggregator first to flush pending stats
if s.cacheStatsAgg != nil {
s.logger.Info("Stopping cache stats aggregator")
if err := s.cacheStatsAgg.Stop(ctx); err != nil {
s.logger.Error("failed to stop cache stats aggregator during shutdown", zap.Error(err))
}
}
// Close audit logger to ensure all events are written
if s.auditLogger != nil {
if err := s.auditLogger.Close(); err != nil {
s.logger.Error("failed to close audit logger during shutdown", zap.Error(err))
}
}
return s.server.Shutdown(ctx)
}
// handleHealth is the HTTP handler for the health check endpoint.
// It responds with a JSON payload containing the server status,
// current timestamp, and application version.
//
// This endpoint can be used by load balancers, monitoring tools,
// and container orchestration systems to verify service health.
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
response := HealthResponse{
Status: "ok",
Timestamp: time.Now(),
Version: Version,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
s.logger.Error("Error encoding health response", zap.Error(err))
return
}
// A 200 OK status is written implicitly on the first successful write
}
// handleReady is used for readiness probes.
func (s *Server) handleReady(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ready"))
}
// handleLive is used for liveness probes.
func (s *Server) handleLive(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("alive"))
}
// handleMetrics returns basic runtime metrics in JSON format.
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
m := struct {
UptimeSeconds float64 `json:"uptime_seconds"`
RequestCount int64 `json:"request_count"`
ErrorCount int64 `json:"error_count"`
// Cache metrics (provider-agnostic counters)
CacheHits int64 `json:"cache_hits"`
CacheMisses int64 `json:"cache_misses"`
CacheBypass int64 `json:"cache_bypass"`
CacheStores int64 `json:"cache_stores"`
}{
UptimeSeconds: time.Since(s.metrics.StartTime).Seconds(),
}
if s.proxy != nil {
pm := s.proxy.Metrics()
m.RequestCount = pm.RequestCount
m.ErrorCount = pm.ErrorCount
m.CacheHits = pm.CacheHits
m.CacheMisses = pm.CacheMisses
m.CacheBypass = pm.CacheBypass
m.CacheStores = pm.CacheStores
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(m); err != nil {
s.logger.Error("Failed to encode metrics", zap.Error(err))
http.Error(w, "Failed to encode metrics", http.StatusInternalServerError)
}
}
// managementAuthMiddleware checks the management token in the Authorization header
func (s *Server) managementAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
const prefix = "Bearer "
header := r.Header.Get("Authorization")
if !strings.HasPrefix(header, prefix) || len(header) <= len(prefix) {
http.Error(w, `{"error":"missing or invalid Authorization header"}`, http.StatusUnauthorized)
return
}
token := header[len(prefix):]
if token != s.config.ManagementToken {
http.Error(w, `{"error":"invalid management token"}`, http.StatusUnauthorized)
return
}
next(w, r)
}
}
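// Client-side sketch of the expected header, using only the standard library;
// the URL and token are placeholders.
//
//	req, _ := http.NewRequest(http.MethodGet, "http://localhost:8080/manage/projects", nil)
//	req.Header.Set("Authorization", "Bearer "+managementToken)
//	resp, err := http.DefaultClient.Do(req)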
// GET /manage/projects
// Only /manage/projects (no trailing slash) is registered for handleProjects.
// This ensures /manage/projects and /manage/projects/ are handled identically,
// while /manage/projects/{id} is routed to handleProjectByID, avoiding
// ambiguity and double handling in Go's http.ServeMux.
func (s *Server) handleProjects(w http.ResponseWriter, r *http.Request) {
s.logger.Debug("handleProjects: START", zap.String("method", r.Method), zap.String("path", r.URL.Path))
// Normalize path: treat /manage/projects/ as /manage/projects
if r.URL.Path == "/manage/projects/" {
r.URL.Path = "/manage/projects"
}
// DEBUG: Log method and headers
for k, v := range r.Header {
if strings.EqualFold(k, "Authorization") {
s.logger.Debug("handleProjects: header", zap.String("key", k), zap.String("value", "******"))
} else {
s.logger.Debug("handleProjects: header", zap.String("key", k), zap.Any("value", v))
}
}
// Mask management token in logs
maskedToken := "******"
if len(s.config.ManagementToken) > 4 {
maskedToken = s.config.ManagementToken[:4] + "******"
}
s.logger.Debug("handleProjects: config.ManagementToken", zap.String("ManagementToken", maskedToken))
if !s.checkManagementAuth(w, r) {
s.logger.Debug("handleProjects: END (auth failed)")
return
}
ctx := r.Context()
requestID := getRequestID(ctx)
switch r.Method {
case http.MethodGet:
s.logger.Info("listing projects", zap.String("request_id", requestID))
s.handleListProjects(w, r.WithContext(ctx))
case http.MethodPost:
s.logger.Info("creating project", zap.String("request_id", requestID))
s.handleCreateProject(w, r.WithContext(ctx))
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
s.logger.Debug("handleProjects: END", zap.String("method", r.Method), zap.String("path", r.URL.Path))
}
func (s *Server) handleListProjects(w http.ResponseWriter, r *http.Request) {
s.logger.Debug("handleListProjects: START")
ctx := r.Context()
requestID := getRequestID(ctx)
projects, err := s.projectStore.ListProjects(ctx)
if err != nil {
s.logger.Error("failed to list projects", zap.Error(err))
// Audit: project list failure
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectList, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithError(err))
http.Error(w, `{"error":"failed to list projects"}`, http.StatusInternalServerError)
s.logger.Debug("handleListProjects: END (error)")
return
}
// Audit: project list success
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectList, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithDetail("project_count", len(projects)))
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(projects); err != nil {
s.logger.Error("failed to encode projects response", zap.Error(err))
s.logger.Debug("handleListProjects: END (encode error)")
} else {
s.logger.Debug("handleListProjects: END (success)")
}
}
// POST /manage/projects
func (s *Server) handleCreateProject(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
var req struct {
Name string `json:"name"`
OpenAIAPIKey string `json:"openai_api_key"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.logger.Error("invalid request body", zap.Error(err), zap.String("request_id", requestID))
// Audit: project creation failure - invalid request
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithError(err).
WithDetail("validation_error", "invalid request body"))
http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest)
return
}
if req.Name == "" || req.OpenAIAPIKey == "" {
s.logger.Error("missing required fields", zap.String("name", req.Name), zap.String("openai_api_key", req.OpenAIAPIKey), zap.String("request_id", requestID))
// Audit: project creation failure - missing fields
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("validation_error", "missing required fields").
WithDetail("name_provided", req.Name != "").
WithDetail("api_key_provided", req.OpenAIAPIKey != ""))
http.Error(w, `{"error":"name and openai_api_key are required"}`, http.StatusBadRequest)
return
}
id := uuid.NewString()
now := time.Now().UTC()
project := proxy.Project{
ID: id,
Name: req.Name,
OpenAIAPIKey: req.OpenAIAPIKey,
IsActive: true, // Projects are active by default
CreatedAt: now,
UpdatedAt: now,
}
if err := s.projectStore.CreateProject(ctx, project); err != nil {
s.logger.Error("failed to create project", zap.Error(err), zap.String("name", req.Name), zap.String("request_id", requestID))
// Audit: project creation failure - store error
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithError(err).
WithDetail("project_name", req.Name))
http.Error(w, `{"error":"failed to create project"}`, http.StatusInternalServerError)
return
}
s.logger.Info("project created", zap.String("id", id), zap.String("name", req.Name), zap.String("request_id", requestID))
// Audit: project creation success
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectCreate, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithProjectID(id).
WithDetail("project_name", req.Name))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
if err := json.NewEncoder(w).Encode(project); err != nil {
s.logger.Error("failed to encode project response", zap.Error(err))
}
}
// GET /manage/projects/{id}
func (s *Server) handleGetProject(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := strings.TrimPrefix(r.URL.Path, "/manage/projects/")
if id == "" || strings.Contains(id, "/") {
s.logger.Error("invalid project id", zap.String("id", id))
http.Error(w, `{"error":"invalid project id"}`, http.StatusBadRequest)
return
}
project, err := s.projectStore.GetProjectByID(ctx, id)
if err != nil {
s.logger.Error("project not found", zap.String("id", id), zap.Error(err))
http.Error(w, `{"error":"project not found"}`, http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(project); err != nil {
s.logger.Error("failed to encode project response", zap.Error(err))
}
}
// PATCH /manage/projects/{id}
func (s *Server) handleUpdateProject(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
id := strings.TrimPrefix(r.URL.Path, "/manage/projects/")
if id == "" || strings.Contains(id, "/") {
s.logger.Error("invalid project id for update", zap.String("id", id))
// Audit: project update failure - invalid ID
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("validation_error", "invalid project id").
WithDetail("provided_id", id))
http.Error(w, `{"error":"invalid project id"}`, http.StatusBadRequest)
return
}
var req struct {
Name *string `json:"name,omitempty"`
OpenAIAPIKey *string `json:"openai_api_key,omitempty"`
IsActive *bool `json:"is_active,omitempty"`
RevokeTokens *bool `json:"revoke_tokens,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.logger.Error("invalid request body for update", zap.Error(err))
// Audit: project update failure - invalid request body
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithError(err).
WithDetail("validation_error", "invalid request body"))
http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest)
return
}
// Validate revoke_tokens usage: it is only valid when explicitly deactivating the project
if req.RevokeTokens != nil {
if req.IsActive == nil || *req.IsActive {
// Audit: invalid field combination
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithDetail("validation_error", "revoke_tokens requires is_active=false"))
http.Error(w, `{"error":"revoke_tokens requires is_active=false"}`, http.StatusBadRequest)
return
}
}
project, err := s.projectStore.GetProjectByID(ctx, id)
if err != nil {
s.logger.Error("project not found for update", zap.String("id", id), zap.Error(err))
// Audit: project update failure - not found
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithError(err).
WithDetail("error_type", "project not found"))
http.Error(w, `{"error":"project not found"}`, http.StatusNotFound)
return
}
// Track what fields are being updated
var updatedFields []string
if req.Name != nil {
project.Name = *req.Name
updatedFields = append(updatedFields, "name")
}
if req.OpenAIAPIKey != nil {
project.OpenAIAPIKey = *req.OpenAIAPIKey
updatedFields = append(updatedFields, "openai_api_key")
}
// Handle project activation/deactivation
var shouldRevokeTokens bool
if req.IsActive != nil {
if *req.IsActive != project.IsActive {
project.IsActive = *req.IsActive
updatedFields = append(updatedFields, "is_active")
// If deactivating project, set deactivated timestamp
if !*req.IsActive {
now := time.Now().UTC()
project.DeactivatedAt = &now
updatedFields = append(updatedFields, "deactivated_at")
// Check if tokens should be revoked when deactivating
if req.RevokeTokens != nil && *req.RevokeTokens {
shouldRevokeTokens = true
}
} else {
// Reactivating project, clear deactivated timestamp
project.DeactivatedAt = nil
}
}
}
project.UpdatedAt = time.Now().UTC()
if err := s.projectStore.UpdateProject(ctx, project); err != nil {
s.logger.Error("failed to update project", zap.String("id", id), zap.Error(err))
// Audit: project update failure - store error
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithError(err).
WithDetail("updated_fields", updatedFields))
http.Error(w, `{"error":"failed to update project"}`, http.StatusInternalServerError)
return
}
// Revoke project tokens if requested
var revokedTokensCount int
if shouldRevokeTokens {
tokens, err := s.tokenStore.GetTokensByProjectID(ctx, id)
if err != nil {
s.logger.Warn("failed to get project tokens for revocation", zap.String("project_id", id), zap.Error(err))
} else {
// Revoke all active tokens for this project
for _, t := range tokens {
if t.IsActive {
t.IsActive = false
t.DeactivatedAt = nowPtrUTC()
if err := s.tokenStore.UpdateToken(ctx, t); err != nil {
s.logger.Warn("failed to revoke token during project deactivation",
zap.String("token_id", t.Token),
zap.String("project_id", id),
zap.Error(err))
} else {
revokedTokensCount++
}
}
}
}
}
s.logger.Info("project updated", zap.String("id", id), zap.Strings("updated_fields", updatedFields))
// Audit: project update success
auditEvent := s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithProjectID(id).
WithDetail("updated_fields", updatedFields).
WithDetail("project_name", project.Name)
if shouldRevokeTokens {
auditEvent.WithDetail("tokens_revoked", revokedTokensCount)
}
_ = s.auditLogger.Log(auditEvent)
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(project); err != nil {
s.logger.Error("failed to encode project response", zap.Error(err))
}
}
// DELETE /manage/projects/{id} - always returns 405 Method Not Allowed; project deletion is not permitted
func (s *Server) handleDeleteProject(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
id := strings.TrimPrefix(r.URL.Path, "/manage/projects/")
// Audit: project delete attempt - method not allowed
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectDelete, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithDetail("error_type", "method not allowed").
WithDetail("reason", "project deletion is not permitted"))
w.Header().Set("Allow", "GET, PATCH")
http.Error(w, `{"error":"method not allowed","message":"project deletion is not permitted"}`, http.StatusMethodNotAllowed)
}
// POST /manage/projects/{id}/tokens/revoke - Bulk revoke all tokens for a project
func (s *Server) handleBulkRevokeProjectTokens(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
// Extract project ID from path
pathSuffix := strings.TrimPrefix(r.URL.Path, "/manage/projects/")
if !strings.HasSuffix(pathSuffix, "/tokens/revoke") {
s.logger.Error("invalid bulk revoke path", zap.String("path", r.URL.Path), zap.String("request_id", requestID))
http.Error(w, `{"error":"invalid path"}`, http.StatusBadRequest)
return
}
projectID := strings.TrimSuffix(pathSuffix, "/tokens/revoke")
if projectID == "" {
s.logger.Error("invalid project ID in bulk revoke path", zap.String("path", r.URL.Path), zap.String("request_id", requestID))
http.Error(w, `{"error":"invalid path"}`, http.StatusBadRequest)
return
}
// Verify project exists
_, err := s.projectStore.GetProjectByID(ctx, projectID)
if err != nil {
s.logger.Error("project not found for bulk token revoke", zap.String("project_id", projectID), zap.Error(err), zap.String("request_id", requestID))
// Audit: bulk revoke failure - project not found
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevokeBatch, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(projectID).
WithError(err).
WithDetail("error_type", "project not found"))
http.Error(w, `{"error":"project not found"}`, http.StatusNotFound)
return
}
// Get all tokens for the project
tokens, err := s.tokenStore.GetTokensByProjectID(ctx, projectID)
if err != nil {
s.logger.Error("failed to get tokens for bulk revoke", zap.String("project_id", projectID), zap.Error(err), zap.String("request_id", requestID))
// Audit: bulk revoke failure - failed to get tokens
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevokeBatch, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(projectID).
WithError(err).
WithDetail("error_type", "failed to get tokens"))
http.Error(w, `{"error":"failed to get project tokens"}`, http.StatusInternalServerError)
return
}
// Count and revoke active tokens
var revokedCount, alreadyRevokedCount int
var failedRevocations []string
for _, t := range tokens {
if !t.IsActive {
alreadyRevokedCount++
continue
}
// Revoke the token
t.IsActive = false
t.DeactivatedAt = nowPtrUTC()
if err := s.tokenStore.UpdateToken(ctx, t); err != nil {
s.logger.Warn("failed to revoke individual token during bulk revoke",
zap.String("token_id", t.Token),
zap.String("project_id", projectID),
zap.Error(err))
failedRevocations = append(failedRevocations, t.Token)
} else {
revokedCount++
}
}
s.logger.Info("bulk token revocation completed",
zap.String("project_id", projectID),
zap.Int("revoked_count", revokedCount),
zap.Int("already_revoked_count", alreadyRevokedCount),
zap.Int("failed_count", len(failedRevocations)),
zap.String("request_id", requestID),
)
// Audit: bulk revoke success
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevokeBatch, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithProjectID(projectID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithDetail("total_tokens", len(tokens)).
WithDetail("revoked_count", revokedCount).
WithDetail("already_revoked_count", alreadyRevokedCount).
WithDetail("failed_count", len(failedRevocations)))
// Return summary response
response := map[string]interface{}{
"revoked_count": revokedCount,
"already_revoked_count": alreadyRevokedCount,
"total_tokens": len(tokens),
}
if len(failedRevocations) > 0 {
response["failed_count"] = len(failedRevocations)
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.logger.Error("failed to encode bulk revoke response", zap.Error(err))
}
}
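// Response shape sketch for a successful bulk revoke; failed_count appears
// only when at least one token update failed. Values are placeholders.
//
//	{
//		"revoked_count": 3,
//		"already_revoked_count": 1,
//		"total_tokens": 4
//	}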
// checkManagementAuth validates the management token from the Authorization
// header. It writes a 401 response and returns false when validation fails.
func (s *Server) checkManagementAuth(w http.ResponseWriter, r *http.Request) bool {
const prefix = "Bearer "
header := r.Header.Get("Authorization")
maskedHeader := header
if len(header) > 10 {
maskedHeader = header[:10] + "..."
}
s.logger.Debug("checkManagementAuth: header", zap.String("header", maskedHeader))
if !strings.HasPrefix(header, prefix) || len(header) <= len(prefix) {
s.logger.Debug("checkManagementAuth: missing or invalid prefix")
http.Error(w, `{"error":"missing or invalid Authorization header"}`, http.StatusUnauthorized)
return false
}
token := header[len(prefix):]
maskedToken := "******"
if len(s.config.ManagementToken) > 4 {
maskedToken = s.config.ManagementToken[:4] + "******"
}
s.logger.Debug("checkManagementAuth: token compare", zap.String("token", token), zap.String("expected", maskedToken))
if token != s.config.ManagementToken {
s.logger.Debug("checkManagementAuth: token mismatch")
http.Error(w, `{"error":"invalid management token"}`, http.StatusUnauthorized)
return false
}
s.logger.Debug("checkManagementAuth: token match")
return true
}
// handleProjectByID routes /manage/projects/{id} requests (including the
// nested /tokens/revoke action) to the method-specific handlers.
func (s *Server) handleProjectByID(w http.ResponseWriter, r *http.Request) {
// Check if this is a bulk token revoke request
if strings.HasSuffix(r.URL.Path, "/tokens/revoke") && r.Method == http.MethodPost {
s.handleBulkRevokeProjectTokens(w, r)
return
}
switch r.Method {
case http.MethodGet:
s.handleGetProject(w, r)
case http.MethodPatch:
s.handleUpdateProject(w, r)
case http.MethodDelete:
s.handleDeleteProject(w, r)
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
// Handler for /manage/tokens (POST: create, GET: list)
func (s *Server) handleTokens(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
switch r.Method {
case http.MethodPost:
var req struct {
ProjectID string `json:"project_id"`
DurationMinutes int `json:"duration_minutes"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.logger.Error("invalid token create request body", zap.Error(err), zap.String("request_id", requestID))
// Audit: token creation failure - invalid request
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithError(err).
WithDetail("validation_error", "invalid request body"))
http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest)
return
}
var duration time.Duration
if req.DurationMinutes > 0 {
if req.DurationMinutes > maxDurationMinutes {
s.logger.Error("duration_minutes exceeds maximum allowed", zap.Int("duration_minutes", req.DurationMinutes), zap.String("request_id", requestID))
// Audit: token creation failure - duration too long
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithDetail("validation_error", "duration exceeds maximum").
WithDetail("requested_duration_minutes", req.DurationMinutes).
WithDetail("max_duration_minutes", maxDurationMinutes))
http.Error(w, `{"error":"duration_minutes exceeds maximum allowed"}`, http.StatusBadRequest)
return
}
duration = time.Duration(req.DurationMinutes) * time.Minute
} else {
s.logger.Error("missing required fields for token create", zap.String("project_id", req.ProjectID), zap.Int("duration_minutes", req.DurationMinutes), zap.String("request_id", requestID))
// Audit: token creation failure - missing duration
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithDetail("validation_error", "missing duration_minutes"))
http.Error(w, `{"error":"project_id and duration_minutes are required"}`, http.StatusBadRequest)
return
}
if req.ProjectID == "" {
s.logger.Error("missing project_id for token create", zap.String("request_id", requestID))
// Audit: token creation failure - missing project ID
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("validation_error", "missing project_id"))
http.Error(w, `{"error":"project_id is required"}`, http.StatusBadRequest)
return
}
// Check project exists and is active
project, err := s.projectStore.GetProjectByID(ctx, req.ProjectID)
if err != nil {
s.logger.Error("project not found for token create", zap.String("project_id", req.ProjectID), zap.Error(err), zap.String("request_id", requestID))
// Audit: token creation failure - project not found
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithError(err).
WithDetail("error_type", "project not found"))
http.Error(w, `{"error":"project not found"}`, http.StatusNotFound)
return
}
// Check if project is active
if !project.IsActive {
s.logger.Warn("token creation denied for inactive project", zap.String("project_id", req.ProjectID), zap.String("request_id", requestID))
// Audit: token creation failure - project inactive
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithDetail("error_type", "project_inactive").
WithDetail("reason", "cannot create tokens for inactive projects"))
http.Error(w, `{"error":"cannot create tokens for inactive projects","code":"project_inactive"}`, http.StatusForbidden)
return
}
// Generate token
tokenStr, expiresAt, _, err := token.NewTokenGenerator().GenerateWithOptions(duration, nil)
if err != nil {
s.logger.Error("failed to generate token", zap.Error(err), zap.String("request_id", requestID))
// Audit: token creation failure - generation error
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithError(err).
WithDetail("error_type", "token generation failed"))
http.Error(w, `{"error":"failed to generate token"}`, http.StatusInternalServerError)
return
}
now := time.Now().UTC()
dbToken := token.TokenData{
Token: tokenStr,
ProjectID: req.ProjectID,
ExpiresAt: expiresAt,
IsActive: true,
RequestCount: 0,
CreatedAt: now,
}
if err := s.tokenStore.CreateToken(ctx, dbToken); err != nil {
s.logger.Error("failed to store token", zap.Error(err), zap.String("request_id", requestID))
// Audit: token creation failure - storage error
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithTokenID(tokenStr).
WithError(err).
WithDetail("error_type", "storage failed"))
http.Error(w, `{"error":"failed to store token"}`, http.StatusInternalServerError)
return
}
s.logger.Info("token created",
zap.String("token", token.ObfuscateToken(tokenStr)),
zap.String("project_id", req.ProjectID),
zap.String("request_id", requestID),
)
// Audit: token creation success
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithProjectID(req.ProjectID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithTokenID(tokenStr).
WithDetail("duration_minutes", req.DurationMinutes).
WithDetail("expires_at", expiresAt.Format(time.RFC3339)))
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(map[string]interface{}{
"token": tokenStr,
"expires_at": expiresAt,
}); err != nil {
s.logger.Error("failed to encode token response", zap.Error(err))
}
case http.MethodGet:
projectID := r.URL.Query().Get("projectId")
var tokens []token.TokenData
var err error
if projectID != "" {
tokens, err = s.tokenStore.GetTokensByProjectID(ctx, projectID)
} else {
tokens, err = s.tokenStore.ListTokens(ctx)
}
if err != nil {
s.logger.Error("failed to list tokens", zap.Error(err))
// Audit: token list failure
auditEvent := s.auditEvent(audit.ActionTokenList, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithError(err)
if projectID != "" {
auditEvent.WithProjectID(projectID)
}
_ = s.auditLogger.Log(auditEvent)
http.Error(w, `{"error":"failed to list tokens"}`, http.StatusInternalServerError)
return
}
s.logger.Info("tokens listed", zap.Int("count", len(tokens)))
// Audit: token list success
auditEvent := s.auditEvent(audit.ActionTokenList, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithDetail("token_count", len(tokens))
if projectID != "" {
auditEvent.WithProjectID(projectID).WithDetail("filtered_by_project", true)
}
_ = s.auditLogger.Log(auditEvent)
w.Header().Set("Content-Type", "application/json")
// Create sanitized response without actual token values
sanitizedTokens := make([]TokenListResponse, len(tokens))
for i, t := range tokens {
sanitizedTokens[i] = TokenListResponse{
TokenID: t.Token,
ProjectID: t.ProjectID,
ExpiresAt: t.ExpiresAt,
IsActive: t.IsActive,
RequestCount: t.RequestCount,
MaxRequests: t.MaxRequests,
CreatedAt: t.CreatedAt,
LastUsedAt: t.LastUsedAt,
CacheHitCount: t.CacheHitCount,
}
}
if err := json.NewEncoder(w).Encode(sanitizedTokens); err != nil {
s.logger.Error("failed to encode tokens response", zap.Error(err))
}
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
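// Request/response sketch for POST /manage/tokens; values are placeholders.
//
//	// request body
//	{"project_id": "<project uuid>", "duration_minutes": 60}
//
//	// response body
//	{"token": "<generated token>", "expires_at": "2024-01-01T12:00:00Z"}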
// Handler for /manage/tokens/{id} (GET: retrieve, PATCH: update, DELETE: revoke)
func (s *Server) handleTokenByID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
// Extract token ID from path
tokenID := strings.TrimPrefix(r.URL.Path, "/manage/tokens/")
if tokenID == "" || tokenID == "/" {
s.logger.Error("invalid token ID in path", zap.String("path", r.URL.Path), zap.String("request_id", requestID))
http.Error(w, `{"error":"token ID is required"}`, http.StatusBadRequest)
return
}
switch r.Method {
case http.MethodGet:
s.handleGetToken(w, r, tokenID)
case http.MethodPatch:
s.handleUpdateToken(w, r, tokenID)
case http.MethodDelete:
s.handleRevokeToken(w, r, tokenID)
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
// GET /manage/tokens/{id}
func (s *Server) handleGetToken(w http.ResponseWriter, r *http.Request, tokenID string) {
ctx := r.Context()
requestID := getRequestID(ctx)
// Get token from store
tokenData, err := s.tokenStore.GetTokenByID(ctx, tokenID)
if err != nil {
s.logger.Error("failed to get token", zap.String("token_id", tokenID), zap.Error(err), zap.String("request_id", requestID))
// Audit: token get failure
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRead, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithError(err).
WithDetail("error_type", "token not found"))
http.Error(w, `{"error":"token not found"}`, http.StatusNotFound)
return
}
// Audit: token get success
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRead, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithTokenID(tokenID).
WithProjectID(tokenData.ProjectID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path))
// Create sanitized response without the actual token value
response := TokenListResponse{
TokenID: tokenID,
ProjectID: tokenData.ProjectID,
ExpiresAt: tokenData.ExpiresAt,
IsActive: tokenData.IsActive,
RequestCount: tokenData.RequestCount,
MaxRequests: tokenData.MaxRequests,
CreatedAt: tokenData.CreatedAt,
LastUsedAt: tokenData.LastUsedAt,
CacheHitCount: tokenData.CacheHitCount,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.logger.Error("failed to encode token response", zap.Error(err))
}
}
// PATCH /manage/tokens/{id}
func (s *Server) handleUpdateToken(w http.ResponseWriter, r *http.Request, tokenID string) {
ctx := r.Context()
requestID := getRequestID(ctx)
// Parse request body
var req struct {
IsActive *bool `json:"is_active,omitempty"`
MaxRequests *int `json:"max_requests,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.logger.Error("invalid token update request body", zap.Error(err), zap.String("request_id", requestID))
// Audit: token update failure - invalid request
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithError(err).
WithDetail("validation_error", "invalid request body"))
http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest)
return
}
// Get existing token
tokenData, err := s.tokenStore.GetTokenByID(ctx, tokenID)
if err != nil {
s.logger.Error("failed to get token for update", zap.String("token_id", tokenID), zap.Error(err), zap.String("request_id", requestID))
// Audit: token update failure - token not found
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithError(err).
WithDetail("error_type", "token not found"))
http.Error(w, `{"error":"token not found"}`, http.StatusNotFound)
return
}
// Update fields if provided
updated := false
if req.IsActive != nil {
tokenData.IsActive = *req.IsActive
updated = true
}
if req.MaxRequests != nil {
tokenData.MaxRequests = req.MaxRequests
updated = true
}
if !updated {
s.logger.Error("no fields to update", zap.String("token_id", tokenID), zap.String("request_id", requestID))
// Audit: token update failure - no fields
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithDetail("validation_error", "no fields to update"))
http.Error(w, `{"error":"no fields to update"}`, http.StatusBadRequest)
return
}
// Update token in store
if err := s.tokenStore.UpdateToken(ctx, tokenData); err != nil {
s.logger.Error("failed to update token", zap.String("token_id", tokenID), zap.Error(err), zap.String("request_id", requestID))
// Audit: token update failure - storage error
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithProjectID(tokenData.ProjectID).
WithError(err).
WithDetail("error_type", "storage failed"))
http.Error(w, `{"error":"failed to update token"}`, http.StatusInternalServerError)
return
}
s.logger.Info("token updated",
zap.String("token_id", tokenID),
zap.String("project_id", tokenData.ProjectID),
zap.String("request_id", requestID),
)
// Audit: token update success
auditEvent := s.auditEvent(audit.ActionTokenUpdate, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithTokenID(tokenID).
WithProjectID(tokenData.ProjectID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path)
if req.IsActive != nil {
auditEvent.WithDetail("updated_is_active", *req.IsActive)
}
if req.MaxRequests != nil {
auditEvent.WithDetail("updated_max_requests", *req.MaxRequests)
}
_ = s.auditLogger.Log(auditEvent)
// Return updated token (sanitized)
response := TokenListResponse{
TokenID: tokenID,
ProjectID: tokenData.ProjectID,
ExpiresAt: tokenData.ExpiresAt,
IsActive: tokenData.IsActive,
RequestCount: tokenData.RequestCount,
MaxRequests: tokenData.MaxRequests,
CreatedAt: tokenData.CreatedAt,
LastUsedAt: tokenData.LastUsedAt,
CacheHitCount: tokenData.CacheHitCount,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.logger.Error("failed to encode updated token response", zap.Error(err))
}
}
// DELETE /manage/tokens/{id} (revoke token)
func (s *Server) handleRevokeToken(w http.ResponseWriter, r *http.Request, tokenID string) {
ctx := r.Context()
requestID := getRequestID(ctx)
// Get existing token first to verify it exists and get project ID
tokenData, err := s.tokenStore.GetTokenByID(ctx, tokenID)
if err != nil {
s.logger.Error("failed to get token for revocation", zap.String("token_id", tokenID), zap.Error(err), zap.String("request_id", requestID))
// Audit: token revoke failure - token not found
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevoke, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithError(err).
WithDetail("error_type", "token not found"))
http.Error(w, `{"error":"token not found"}`, http.StatusNotFound)
return
}
// Check if already inactive
if !tokenData.IsActive {
s.logger.Warn("token already revoked", zap.String("token_id", tokenID), zap.String("request_id", requestID))
// Audit: token revoke success (idempotent)
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevoke, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithTokenID(tokenID).
WithProjectID(tokenData.ProjectID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithDetail("already_revoked", true))
w.WriteHeader(http.StatusNoContent)
return
}
// Revoke token by setting is_active to false
tokenData.IsActive = false
tokenData.DeactivatedAt = nowPtrUTC()
if err := s.tokenStore.UpdateToken(ctx, tokenData); err != nil {
s.logger.Error("failed to revoke token", zap.String("token_id", tokenID), zap.Error(err), zap.String("request_id", requestID))
// Audit: token revoke failure - storage error
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevoke, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithProjectID(tokenData.ProjectID).
WithError(err).
WithDetail("error_type", "storage failed"))
http.Error(w, `{"error":"failed to revoke token"}`, http.StatusInternalServerError)
return
}
s.logger.Info("token revoked",
zap.String("token_id", tokenID),
zap.String("project_id", tokenData.ProjectID),
zap.String("request_id", requestID),
)
// Audit: token revoke success
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevoke, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithTokenID(tokenID).
WithProjectID(tokenData.ProjectID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path))
w.WriteHeader(http.StatusNoContent)
}
func getRequestID(ctx context.Context) string {
if requestID, ok := logging.GetRequestID(ctx); ok && requestID != "" {
return requestID
}
return uuid.New().String()
}
// nowPtrUTC is a convenience helper that returns the current UTC time as a *time.Time.
func nowPtrUTC() *time.Time {
t := time.Now().UTC()
return &t
}
// parseInt parses a string as a base-10 integer, returning defaultValue for
// empty or malformed input.
func parseInt(s string, defaultValue int) int {
if s == "" {
return defaultValue
}
result, err := strconv.Atoi(s)
if err != nil {
return defaultValue
}
return result
}
// logRequestMiddleware logs all incoming requests with timing information using structured logging
func (s *Server) logRequestMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
// Get or generate request ID from header
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = uuid.New().String()
}
// Get or generate correlation ID from header
correlationID := r.Header.Get("X-Correlation-ID")
if correlationID == "" {
correlationID = uuid.New().String()
}
// Add to context using our new context helpers
ctx := logging.WithRequestID(r.Context(), requestID)
ctx = logging.WithCorrelationID(ctx, correlationID)
// Set response headers
w.Header().Set("X-Request-ID", requestID)
w.Header().Set("X-Correlation-ID", correlationID)
// Create a response writer that captures status code
rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
// Get client IP
clientIP := s.getClientIP(r)
// Create logger with request context
reqLogger := logging.WithRequestContext(ctx, s.logger)
reqLogger = logging.WithCorrelationContext(ctx, reqLogger)
reqLogger.Info("request started",
logging.ClientIP(clientIP),
zap.String("method", r.Method),
zap.String("path", r.URL.Path),
zap.String("user_agent", r.UserAgent()),
)
// Call the next handler
next(rw, r.WithContext(ctx))
duration := time.Since(startTime)
durationMs := int(duration.Milliseconds())
// Log completion with canonical fields
if rw.statusCode >= 500 {
reqLogger.Error("request completed with server error",
logging.RequestFields(requestID, r.Method, r.URL.Path, rw.statusCode, durationMs)...,
)
} else {
reqLogger.Info("request completed",
logging.RequestFields(requestID, r.Method, r.URL.Path, rw.statusCode, durationMs)...,
)
}
}
}
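// Propagation sketch: callers may supply their own IDs, which the middleware
// echoes back in the response; otherwise fresh UUIDs are generated. Header
// names match the code above; the values are placeholders.
//
//	req.Header.Set("X-Request-ID", "req-123")
//	req.Header.Set("X-Correlation-ID", "corr-456")
//	// The response carries the same X-Request-ID / X-Correlation-ID headers.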
// getClientIP extracts the client IP address from the request
func (s *Server) getClientIP(r *http.Request) string {
// Check for X-Forwarded-For header first (in case of proxy)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP in the list
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Check for X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// Fall back to RemoteAddr
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
return r.RemoteAddr[:idx]
}
return r.RemoteAddr
}
// responseWriter wraps http.ResponseWriter to capture status code
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
// Flush forwards Flush to the wrapped ResponseWriter when it implements http.Flusher, preserving streaming behavior
func (rw *responseWriter) Flush() {
if f, ok := rw.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// handleNotFound is a catch-all handler for unmatched routes
func (s *Server) handleNotFound(w http.ResponseWriter, r *http.Request) {
s.logger.Info("route not found",
zap.String("method", r.Method),
zap.String("path", r.URL.Path),
zap.String("remote_addr", r.RemoteAddr),
)
http.NotFound(w, r)
}
// EventBus returns the event bus used by the server (nil only if the server was constructed without one, e.g. in tests)
func (s *Server) EventBus() eventbus.EventBus {
return s.eventBus
}
// auditEvent creates a new audit event with common fields filled from the HTTP request
func (s *Server) auditEvent(action string, actor string, result audit.ResultType, r *http.Request, requestID string) *audit.Event {
clientIP := s.getClientIP(r)
// Prefer forwarded UA and referer from Admin UI if present
forwardedUA := r.Header.Get("X-Forwarded-User-Agent")
userAgent := r.UserAgent()
if forwardedUA != "" {
userAgent = forwardedUA
}
forwardedRef := r.Header.Get("X-Forwarded-Referer")
ev := audit.NewEvent(action, actor, result).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithClientIP(clientIP).
WithUserAgent(userAgent)
if forwardedRef != "" {
ev = ev.WithDetail("referer", forwardedRef)
}
if r.Header.Get("X-Admin-Origin") == "1" {
ev = ev.WithDetail("origin", "admin-ui")
}
return ev
}
// Handler for /manage/audit (GET: list)
func (s *Server) handleAuditEvents(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
// Only proceed if we have database access
if s.db == nil {
s.logger.Error("audit events requested but database not available", zap.String("request_id", requestID))
http.Error(w, `{"error":"audit events not available"}`, http.StatusServiceUnavailable)
return
}
// Parse query parameters for filtering
query := r.URL.Query()
filters := database.AuditEventFilters{
Action: query.Get("action"),
ClientIP: query.Get("client_ip"),
ProjectID: query.Get("project_id"),
Outcome: query.Get("outcome"),
Actor: query.Get("actor"),
RequestID: query.Get("request_id"),
CorrelationID: query.Get("correlation_id"),
Method: query.Get("method"),
Path: query.Get("path"),
Search: query.Get("search"),
}
// Parse time filters
if startTime := query.Get("start_time"); startTime != "" {
filters.StartTime = &startTime
}
if endTime := query.Get("end_time"); endTime != "" {
filters.EndTime = &endTime
}
// Parse pagination
page := parseInt(query.Get("page"), 1)
pageSize := parseInt(query.Get("page_size"), 20)
if pageSize > 100 {
pageSize = 100 // Limit page size
}
filters.Limit = pageSize
filters.Offset = (page - 1) * pageSize
// Get audit events
events, err := s.db.ListAuditEvents(ctx, filters)
if err != nil {
s.logger.Error("failed to list audit events", zap.Error(err), zap.String("request_id", requestID))
http.Error(w, `{"error":"failed to list audit events"}`, http.StatusInternalServerError)
return
}
// Get total count for pagination
totalCount, err := s.db.CountAuditEvents(ctx, filters)
if err != nil {
s.logger.Error("failed to count audit events", zap.Error(err), zap.String("request_id", requestID))
http.Error(w, `{"error":"failed to count audit events"}`, http.StatusInternalServerError)
return
}
// Calculate pagination info
totalPages := (totalCount + pageSize - 1) / pageSize
hasNext := page < totalPages
hasPrev := page > 1
response := map[string]interface{}{
"events": events,
"pagination": map[string]interface{}{
"page": page,
"page_size": pageSize,
"total_count": totalCount,
"total_pages": totalPages,
"has_next": hasNext,
"has_prev": hasPrev,
},
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.logger.Error("failed to encode audit events response", zap.Error(err), zap.String("request_id", requestID))
}
// Audit: successful audit events listing
_ = s.auditLogger.Log(s.auditEvent(audit.ActionAuditList, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithDetail("events_count", len(events)).
WithDetail("page", page).
WithDetail("page_size", pageSize).
WithDetail("total_count", totalCount))
}
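// Illustrative request/response shape for the listing endpoint above
// (the action value and counts are hypothetical):
//
//	GET /manage/audit?action=token.create&page=1&page_size=20
//
//	{
//	  "events": [...],
//	  "pagination": {"page": 1, "page_size": 20, "total_count": 42,
//	                 "total_pages": 3, "has_next": true, "has_prev": false}
//	}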
// Handler for /manage/audit/{id} (GET: show)
func (s *Server) handleAuditEventByID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
// Only proceed if we have database access
if s.db == nil {
s.logger.Error("audit event requested but database not available", zap.String("request_id", requestID))
http.Error(w, `{"error":"audit event not available"}`, http.StatusServiceUnavailable)
return
}
// Extract audit event ID from path
id := strings.TrimPrefix(r.URL.Path, "/manage/audit/")
if id == "" {
http.Error(w, `{"error":"audit event ID is required"}`, http.StatusBadRequest)
return
}
// Get specific audit event by ID
event, err := s.db.GetAuditEventByID(ctx, id)
if err != nil {
if strings.Contains(err.Error(), "not found") {
http.Error(w, `{"error":"audit event not found"}`, http.StatusNotFound)
} else {
s.logger.Error("failed to get audit event", zap.Error(err), zap.String("audit_id", id), zap.String("request_id", requestID))
http.Error(w, `{"error":"failed to get audit event"}`, http.StatusInternalServerError)
}
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(event); err != nil {
s.logger.Error("failed to encode audit event response", zap.Error(err), zap.String("audit_id", id), zap.String("request_id", requestID))
}
// Audit: successful audit event retrieval
_ = s.auditLogger.Log(s.auditEvent(audit.ActionAuditShow, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithDetail("audit_event_id", id))
}
// CachePurgeRequest represents the request body for cache purge operations
type CachePurgeRequest struct {
Method string `json:"method" binding:"required"`
URL string `json:"url" binding:"required"`
Prefix string `json:"prefix,omitempty"`
}
// CachePurgeResponse represents the response body for cache purge operations
type CachePurgeResponse struct {
Deleted interface{} `json:"deleted"` // bool for exact purge, int for prefix purge
}
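// Illustrative payloads (hypothetical URL): an exact purge
//
//	{"method": "GET", "url": "https://api.openai.com/v1/models"}
//
// returns {"deleted": true} or {"deleted": false}, while a prefix purge
// (method and url are still required by validation, but only prefix is used)
//
//	{"method": "GET", "url": "unused", "prefix": "chat:"}
//
// returns {"deleted": <number of keys removed>}.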
// Handler for POST /manage/cache/purge
func (s *Server) handleCachePurge(w http.ResponseWriter, r *http.Request) {
requestID := getRequestID(r.Context())
if r.Method != http.MethodPost {
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
return
}
// Check if proxy and cache are available
if s.proxy == nil {
s.logger.Error("proxy not initialized", zap.String("request_id", requestID))
http.Error(w, `{"error":"proxy not available"}`, http.StatusInternalServerError)
return
}
cache := s.proxy.Cache()
if cache == nil {
s.logger.Warn("cache purge attempted but caching is disabled", zap.String("request_id", requestID))
http.Error(w, `{"error":"caching is disabled"}`, http.StatusBadRequest)
_ = s.auditLogger.Log(s.auditEvent(audit.ActionCachePurge, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("reason", "caching_disabled"))
return
}
var req CachePurgeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.logger.Warn("invalid JSON in cache purge request", zap.Error(err), zap.String("request_id", requestID))
http.Error(w, `{"error":"invalid JSON"}`, http.StatusBadRequest)
_ = s.auditLogger.Log(s.auditEvent(audit.ActionCachePurge, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("reason", "invalid_json"))
return
}
// Validate required fields
if req.Method == "" || req.URL == "" {
s.logger.Warn("missing required fields in cache purge request",
zap.String("method", req.Method), zap.String("url", req.URL), zap.String("request_id", requestID))
http.Error(w, `{"error":"method and url are required"}`, http.StatusBadRequest)
_ = s.auditLogger.Log(s.auditEvent(audit.ActionCachePurge, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("reason", "missing_fields"))
return
}
var response CachePurgeResponse
var auditDetails map[string]interface{}
if req.Prefix != "" {
// Prefix purge
deleted := cache.PurgePrefix(req.Prefix)
response.Deleted = deleted
auditDetails = map[string]interface{}{
"purge_type": "prefix",
"prefix": req.Prefix,
"deleted": deleted,
}
s.logger.Info("cache prefix purge completed",
zap.String("prefix", req.Prefix), zap.Int("deleted", deleted), zap.String("request_id", requestID))
} else {
// Exact key purge: build a synthetic request from the method and URL to
// derive the cache key
mockURL, err := url.Parse(req.URL)
if err != nil {
s.logger.Warn("invalid URL in cache purge request", zap.Error(err), zap.String("url", req.URL), zap.String("request_id", requestID))
http.Error(w, `{"error":"invalid URL"}`, http.StatusBadRequest)
_ = s.auditLogger.Log(s.auditEvent(audit.ActionCachePurge, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("reason", "invalid_url"))
return
}
mockReq := &http.Request{
Method: req.Method,
URL: mockURL,
Header: make(http.Header),
}
// Generate cache key using existing helper
cacheKey := proxy.CacheKeyFromRequest(mockReq)
deleted := cache.Purge(cacheKey)
response.Deleted = deleted
auditDetails = map[string]interface{}{
"purge_type": "exact",
"method": req.Method,
"url": req.URL,
"cache_key": cacheKey,
"deleted": deleted,
}
s.logger.Info("cache exact purge completed",
zap.String("method", req.Method), zap.String("url", req.URL),
zap.String("cache_key", cacheKey), zap.Bool("deleted", deleted), zap.String("request_id", requestID))
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.logger.Error("failed to encode cache purge response", zap.Error(err), zap.String("request_id", requestID))
return
}
// Audit: successful cache purge
auditEvent := s.auditEvent(audit.ActionCachePurge, audit.ActorManagement, audit.ResultSuccess, r, requestID)
for k, v := range auditDetails {
auditEvent = auditEvent.WithDetail(k, v)
}
_ = s.auditLogger.Log(auditEvent)
}
// Package setup provides configuration setup and management utilities.
package setup
import (
"fmt"
"os"
"path/filepath"
"github.com/sofatutor/llm-proxy/internal/utils"
)
// SetupConfig holds configuration parameters for setup
type SetupConfig struct {
ConfigPath string
OpenAIAPIKey string
ManagementToken string
DatabasePath string
ListenAddr string
}
// ValidateConfig validates the setup configuration
func (sc *SetupConfig) ValidateConfig() error {
if sc.OpenAIAPIKey == "" {
return fmt.Errorf("OpenAI API key is required")
}
if sc.ConfigPath == "" {
return fmt.Errorf("config path is required")
}
if sc.DatabasePath == "" {
return fmt.Errorf("database path is required")
}
if sc.ListenAddr == "" {
return fmt.Errorf("listen address is required")
}
return nil
}
// GenerateManagementToken generates a management token if not provided
func (sc *SetupConfig) GenerateManagementToken() error {
if sc.ManagementToken == "" {
token, err := utils.GenerateSecureToken(16)
if err != nil {
return fmt.Errorf("failed to generate management token: %w", err)
}
sc.ManagementToken = token
}
return nil
}
// WriteConfigFile writes the configuration to a file
func (sc *SetupConfig) WriteConfigFile() error {
if err := sc.ValidateConfig(); err != nil {
return err
}
// Ensure directory exists
dir := filepath.Dir(sc.ConfigPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
// Ensure database directory exists
dbDir := filepath.Dir(sc.DatabasePath)
if err := os.MkdirAll(dbDir, 0755); err != nil {
return fmt.Errorf("failed to create database directory: %w", err)
}
content := fmt.Sprintf(`# LLM Proxy Configuration
OPENAI_API_KEY=%s
MANAGEMENT_TOKEN=%s
DATABASE_PATH=%s
LISTEN_ADDR=%s
LOG_LEVEL=info
`, sc.OpenAIAPIKey, sc.ManagementToken, sc.DatabasePath, sc.ListenAddr)
if err := os.WriteFile(sc.ConfigPath, []byte(content), 0600); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
return nil
}
// RunNonInteractiveSetup performs non-interactive setup with the given configuration
func RunNonInteractiveSetup(sc *SetupConfig) error {
if err := sc.GenerateManagementToken(); err != nil {
return fmt.Errorf("failed to generate management token: %w", err)
}
if err := sc.WriteConfigFile(); err != nil {
return fmt.Errorf("failed to write configuration: %w", err)
}
return nil
}
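// Example (hypothetical paths and key values): a minimal non-interactive
// setup run using the helpers above.
func exampleNonInteractiveSetup() error {
	sc := &SetupConfig{
		ConfigPath:   "/etc/llm-proxy/.env",
		OpenAIAPIKey: "sk-example",
		DatabasePath: "/var/lib/llm-proxy/llm-proxy.db",
		ListenAddr:   ":8080",
	}
	if err := RunNonInteractiveSetup(sc); err != nil {
		return err
	}
	// ManagementToken was generated because it was left empty.
	fmt.Println("management token:", sc.ManagementToken)
	return nil
}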
package token
import (
"container/heap"
"context"
"fmt"
"sync"
"time"
)
// CacheEntry represents a cached token data with expiration
type CacheEntry struct {
Data TokenData
ValidUntil time.Time
}
// cacheEntry is a heap entry for eviction, with index for fast updates
// and tokenID for lookup.
type cacheEntry struct {
tokenID string
validUntil time.Time
insertedAt int64 // strictly increasing for FIFO eviction
index int // index in the heap
}
type cacheEntryHeap []*cacheEntry
func (h cacheEntryHeap) Len() int { return len(h) }
func (h cacheEntryHeap) Less(i, j int) bool { return h[i].insertedAt < h[j].insertedAt } // FIFO eviction
func (h cacheEntryHeap) Swap(i, j int) {
h[i], h[j] = h[j], h[i]
h[i].index = i
h[j].index = j
}
func (h *cacheEntryHeap) Push(x interface{}) {
entry := x.(*cacheEntry)
entry.index = len(*h)
*h = append(*h, entry)
}
func (h *cacheEntryHeap) Pop() interface{} {
old := *h
n := len(old)
item := old[n-1]
item.index = -1 // for safety
*h = old[0 : n-1]
return item
}
// CachedValidator wraps a TokenValidator with caching
type CachedValidator struct {
validator TokenValidator
cache map[string]CacheEntry
cacheMutex sync.RWMutex
cacheTTL time.Duration
maxCacheSize int
// Min-heap for eviction
heap cacheEntryHeap
heapIndex map[string]*cacheEntry // tokenID -> *cacheEntry
insertCounter int64 // strictly increasing counter for insertedAt
// For cache stats
hits int
misses int
evictions int
statsMutex sync.Mutex
}
// CacheOptions defines the options for the token cache
type CacheOptions struct {
// Time-to-live for cache entries (default: 5 minutes)
TTL time.Duration
// Maximum size of the cache (default: 1000)
MaxSize int
// Whether to enable automatic cache cleanup (default: true)
EnableCleanup bool
// Interval for cache cleanup (default: 1 minute)
CleanupInterval time.Duration
}
// DefaultCacheOptions returns the default cache options
func DefaultCacheOptions() CacheOptions {
return CacheOptions{
TTL: 5 * time.Minute,
MaxSize: 1000,
EnableCleanup: true,
CleanupInterval: 1 * time.Minute,
}
}
// NewCachedValidator creates a new validator with caching
func NewCachedValidator(validator TokenValidator, options ...CacheOptions) *CachedValidator {
opts := DefaultCacheOptions()
if len(options) > 0 {
opts = options[0]
}
cv := &CachedValidator{
validator: validator,
cache: make(map[string]CacheEntry),
cacheTTL: opts.TTL,
maxCacheSize: opts.MaxSize,
heap: make(cacheEntryHeap, 0, opts.MaxSize),
heapIndex: make(map[string]*cacheEntry, opts.MaxSize),
}
// Start cache cleanup if enabled
if opts.EnableCleanup {
go cv.startCleanup(opts.CleanupInterval)
}
return cv
}
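// Sketch: wrapping an existing validator with a smaller, shorter-lived cache.
// The durations and sizes are illustrative; baseValidator is any TokenValidator.
func exampleCachedValidator(baseValidator TokenValidator) *CachedValidator {
	return NewCachedValidator(baseValidator, CacheOptions{
		TTL:             2 * time.Minute,
		MaxSize:         500,
		EnableCleanup:   true,
		CleanupInterval: 30 * time.Second,
	})
}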
// ValidateToken validates a token using the cache when possible
func (cv *CachedValidator) ValidateToken(ctx context.Context, tokenID string) (string, error) {
// Check cache first
projectID, found := cv.checkCache(tokenID)
if found {
return projectID, nil
}
// Cache miss, validate using the underlying validator
projectID, err := cv.validator.ValidateToken(ctx, tokenID)
if err != nil {
return "", err
}
// Cache the successful validation
cv.cacheToken(ctx, tokenID)
return projectID, nil
}
// ValidateTokenWithTracking validates a token and tracks usage (bypasses cache for tracking)
func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenID string) (string, error) {
// Always use the underlying validator for tracking requests
projectID, err := cv.validator.ValidateTokenWithTracking(ctx, tokenID)
if err != nil {
return "", err
}
// Invalidate any cached entry so the next validation reflects the
// updated usage count
cv.invalidateCache(tokenID)
return projectID, nil
}
// checkCache checks if a token is in the cache and still valid
func (cv *CachedValidator) checkCache(tokenID string) (string, bool) {
cv.cacheMutex.RLock()
entry, found := cv.cache[tokenID]
cv.cacheMutex.RUnlock()
// Not in cache
if !found {
cv.statsMutex.Lock()
cv.misses++
cv.statsMutex.Unlock()
return "", false
}
// In cache but expired: remove the map entry and its heap entry so the
// eviction heap stays consistent with the cache
now := time.Now()
if now.After(entry.ValidUntil) {
	cv.cacheMutex.Lock()
	delete(cv.cache, tokenID)
	if heapEntry, ok := cv.heapIndex[tokenID]; ok {
		heap.Remove(&cv.heap, heapEntry.index)
		delete(cv.heapIndex, tokenID)
	}
	cv.cacheMutex.Unlock()
	cv.statsMutex.Lock()
	cv.misses++
	cv.evictions++
	cv.statsMutex.Unlock()
	return "", false
}
// In cache and valid
cv.statsMutex.Lock()
cv.hits++
cv.statsMutex.Unlock()
return entry.Data.ProjectID, true
}
// cacheToken retrieves and caches a token
func (cv *CachedValidator) cacheToken(ctx context.Context, tokenID string) {
cv.cacheMutex.Lock()
defer cv.cacheMutex.Unlock()
// Caching requires access to the underlying store; only the standard
// validator exposes it, so other implementations are silently skipped
standardValidator, ok := cv.validator.(*StandardValidator)
if !ok {
	return
}
tokenData, err := standardValidator.store.GetTokenByID(ctx, tokenID)
if err != nil {
return
}
if !tokenData.IsValid() {
return
}
validUntil := time.Now().Add(cv.cacheTTL)
insertedAt := cv.insertCounter
cv.insertCounter++
cv.cache[tokenID] = CacheEntry{
Data: tokenData,
ValidUntil: validUntil,
}
// Remove old heap entry if present
if oldEntry, ok := cv.heapIndex[tokenID]; ok {
idx := oldEntry.index
heap.Remove(&cv.heap, idx)
delete(cv.heapIndex, tokenID)
}
entry := &cacheEntry{
tokenID: tokenID,
validUntil: validUntil,
insertedAt: insertedAt,
}
heap.Push(&cv.heap, entry)
cv.heapIndex[tokenID] = entry
// Evict if over capacity
if cv.maxCacheSize > 0 && len(cv.cache) > cv.maxCacheSize {
cv.evictOldest()
}
}
// invalidateCache removes a token from the cache
func (cv *CachedValidator) invalidateCache(tokenID string) {
cv.cacheMutex.Lock()
delete(cv.cache, tokenID)
// Remove from heap if present
if entry, ok := cv.heapIndex[tokenID]; ok {
idx := entry.index
heap.Remove(&cv.heap, idx)
delete(cv.heapIndex, tokenID)
}
cv.cacheMutex.Unlock()
// cache, heap, and heapIndex are all mutated under cacheMutex, which keeps them consistent
}
// evictOldest removes the single oldest entry from the cache
func (cv *CachedValidator) evictOldest() {
if cv.heap.Len() == 0 {
return
}
entry := heap.Pop(&cv.heap).(*cacheEntry)
// Remove from heapIndex
delete(cv.heapIndex, entry.tokenID)
// Remove from cache
delete(cv.cache, entry.tokenID)
cv.statsMutex.Lock()
cv.evictions++
cv.statsMutex.Unlock()
}
// startCleanup periodically removes expired entries from the cache.
// It runs for the lifetime of the process; there is no stop mechanism.
func (cv *CachedValidator) startCleanup(interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for range ticker.C {
cv.cleanup()
}
}
// cleanup removes expired entries from the cache and the eviction heap
func (cv *CachedValidator) cleanup() {
	now := time.Now()
	cv.cacheMutex.Lock()
	defer cv.cacheMutex.Unlock()
	for k, v := range cv.cache {
		if now.After(v.ValidUntil) {
			delete(cv.cache, k)
			if entry, ok := cv.heapIndex[k]; ok {
				heap.Remove(&cv.heap, entry.index)
				delete(cv.heapIndex, k)
			}
			cv.statsMutex.Lock()
			cv.evictions++
			cv.statsMutex.Unlock()
		}
	}
}
// ClearCache removes all entries from the cache and resets the eviction heap
func (cv *CachedValidator) ClearCache() {
	cv.cacheMutex.Lock()
	cv.cache = make(map[string]CacheEntry)
	cv.heap = make(cacheEntryHeap, 0, cv.maxCacheSize)
	cv.heapIndex = make(map[string]*cacheEntry, cv.maxCacheSize)
	cv.cacheMutex.Unlock()
}
// GetCacheStats returns statistics about the cache
func (cv *CachedValidator) GetCacheStats() (hits, misses, evictions, size int) {
cv.statsMutex.Lock()
hits = cv.hits
misses = cv.misses
evictions = cv.evictions
cv.statsMutex.Unlock()
cv.cacheMutex.RLock()
size = len(cv.cache)
cv.cacheMutex.RUnlock()
return
}
// GetCacheInfo returns a formatted string with cache statistics
func (cv *CachedValidator) GetCacheInfo() string {
hits, misses, evictions, size := cv.GetCacheStats()
total := hits + misses
hitRate := 0.0
if total > 0 {
hitRate = float64(hits) / float64(total) * 100
}
return fmt.Sprintf(
"Cache Stats:\n"+
" Size: %d (max: %d)\n"+
" Hits: %d (%.1f%%)\n"+
" Misses: %d\n"+
" Evictions: %d\n"+
" TTL: %s",
size, cv.maxCacheSize, hits, hitRate, misses, evictions, cv.cacheTTL,
)
}
package token
import (
"errors"
"time"
)
var (
// ErrInvalidDuration is returned when an expiration duration is invalid
ErrInvalidDuration = errors.New("invalid duration")
// ErrExpirationInPast is returned when an expiration time is in the past
ErrExpirationInPast = errors.New("expiration time is in the past")
)
// Common expiration durations
const (
OneHour = time.Hour
OneDay = 24 * time.Hour
OneWeek = 7 * 24 * time.Hour
ThirtyDays = 30 * 24 * time.Hour
NinetyDays = 90 * 24 * time.Hour
NoExpiration = time.Duration(0)
MaxDuration = time.Duration(1<<63 - 1)
)
// CalculateExpiration returns an expiration time based on the current time and the provided duration.
// If duration is 0 or negative, it returns nil (no expiration).
func CalculateExpiration(duration time.Duration) *time.Time {
if duration <= 0 {
return nil // No expiration
}
expiry := time.Now().Add(duration)
return &expiry
}
// CalculateExpirationFrom returns an expiration time based on the provided start time and duration.
// If duration is 0 or negative, it returns nil (no expiration).
func CalculateExpirationFrom(startTime time.Time, duration time.Duration) *time.Time {
if duration <= 0 {
return nil // No expiration
}
expiry := startTime.Add(duration)
return &expiry
}
// ValidateExpiration checks if the given expiration time is valid (nil or in the future).
func ValidateExpiration(expiresAt *time.Time) error {
if expiresAt == nil {
return nil // No expiration is valid
}
if expiresAt.Before(time.Now()) {
return ErrExpirationInPast
}
return nil
}
// IsExpired checks if a token with the given expiration time is expired.
// If expiresAt is nil, the token never expires.
func IsExpired(expiresAt *time.Time) bool {
if expiresAt == nil {
return false // No expiration
}
return time.Now().After(*expiresAt)
}
// TimeUntilExpiration returns the duration until the token expires.
// If expiresAt is nil, it returns the max possible duration (effectively "never").
func TimeUntilExpiration(expiresAt *time.Time) time.Duration {
if expiresAt == nil {
return MaxDuration // Max duration, practically "never"
}
until := time.Until(*expiresAt)
if until < 0 {
return 0 // Already expired
}
return until
}
// ExpiresWithin checks if the token will expire within the given duration.
// If expiresAt is nil, it returns false (never expires).
func ExpiresWithin(expiresAt *time.Time, duration time.Duration) bool {
if expiresAt == nil {
return false // No expiration
}
expiresIn := time.Until(*expiresAt)
return expiresIn >= 0 && expiresIn <= duration
}
// FormatExpirationTime returns a human-readable string for the expiration time.
// If expiresAt is nil, it returns "Never expires".
func FormatExpirationTime(expiresAt *time.Time) string {
if expiresAt == nil {
return "Never expires"
}
if expiresAt.Before(time.Now()) {
return "Expired"
}
return expiresAt.Format(time.RFC3339)
}
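// Illustrative sketch of the helpers above; the durations are arbitrary.
func exampleExpirationHelpers() {
	exp := CalculateExpiration(OneWeek) // *time.Time roughly one week from now
	_ = IsExpired(exp)                  // false until the deadline passes
	_ = ExpiresWithin(exp, ThirtyDays)  // true: one week falls within thirty days
	_ = FormatExpirationTime(nil)       // "Never expires"
}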
package token
import (
"context"
"fmt"
"time"
"go.uber.org/zap"
)
// ManagerStore is the composite store interface required by Manager,
// embedding TokenStore, RevocationStore, and RateLimitStore.
//
//go:generate mockgen -destination=mock_managerstore.go -package=token . ManagerStore
type ManagerStore interface {
TokenStore
RevocationStore
RateLimitStore
}
// Manager provides a unified interface for all token operations
type Manager struct {
validator TokenValidator
revoker *Revoker
limiter *StandardRateLimiter
generator *TokenGenerator
store ManagerStore // Underlying store (must implement all store interfaces)
useCaching bool
}
// NewManager creates a new token manager with the given store
func NewManager(store ManagerStore, useCaching bool) (*Manager, error) {
// The type system guarantees the store satisfies all required interfaces
// Create components
baseValidator := NewValidator(store)
var validator TokenValidator = baseValidator
if useCaching {
validator = NewCachedValidator(baseValidator)
}
revoker := NewRevoker(store)
limiter := NewRateLimiter(store)
generator := NewTokenGenerator()
return &Manager{
validator: validator,
revoker: revoker,
limiter: limiter,
generator: generator,
store: store,
useCaching: useCaching,
}, nil
}
// CreateToken generates a new token with the specified options
func (m *Manager) CreateToken(ctx context.Context, projectID string, options TokenOptions) (TokenData, error) {
// Generate a new token
tokenStr, expiresAt, maxRequests, err := m.generator.GenerateWithOptions(options.Expiration, options.MaxRequests)
if err != nil {
return TokenData{}, fmt.Errorf("failed to generate token: %w", err)
}
// Create token data
now := time.Now()
token := TokenData{
Token: tokenStr,
ProjectID: projectID,
ExpiresAt: expiresAt,
IsActive: true,
RequestCount: 0,
MaxRequests: maxRequests,
CreatedAt: now,
}
// For tests only: if the store is a mock with AddToken, call it
if mockStore, ok := any(m.store).(interface{ AddToken(string, TokenData) }); ok {
mockStore.AddToken(token.Token, token)
}
return token, nil
}
// ValidateToken validates a token without incrementing usage
func (m *Manager) ValidateToken(ctx context.Context, tokenID string) (string, error) {
return m.validator.ValidateToken(ctx, tokenID)
}
// ValidateTokenWithTracking validates a token and increments usage count
func (m *Manager) ValidateTokenWithTracking(ctx context.Context, tokenID string) (string, error) {
return m.validator.ValidateTokenWithTracking(ctx, tokenID)
}
// RevokeToken revokes a token
func (m *Manager) RevokeToken(ctx context.Context, tokenID string) error {
return m.revoker.RevokeToken(ctx, tokenID)
}
// DeleteToken completely removes a token
func (m *Manager) DeleteToken(ctx context.Context, tokenID string) error {
return m.revoker.DeleteToken(ctx, tokenID)
}
// RevokeExpiredTokens revokes all expired tokens
func (m *Manager) RevokeExpiredTokens(ctx context.Context) (int, error) {
return m.revoker.RevokeExpiredTokens(ctx)
}
// RevokeProjectTokens revokes all tokens for a project
func (m *Manager) RevokeProjectTokens(ctx context.Context, projectID string) (int, error) {
return m.revoker.RevokeProjectTokens(ctx, projectID)
}
// GetTokenInfo gets detailed information about a token
func (m *Manager) GetTokenInfo(ctx context.Context, tokenID string) (*TokenInfo, error) {
tokenData, err := m.store.GetTokenByID(ctx, tokenID)
if err != nil {
return nil, err
}
info := GetTokenInfo(tokenData)
return &info, nil
}
// GetTokenStats gets statistics about token usage
func (m *Manager) GetTokenStats(ctx context.Context, tokenID string) (*TokenStats, error) {
tokenData, err := m.store.GetTokenByID(ctx, tokenID)
if err != nil {
return nil, err
}
var remaining int
if tokenData.MaxRequests != nil {
remaining = *tokenData.MaxRequests - tokenData.RequestCount
if remaining < 0 {
remaining = 0
}
} else {
remaining = -1 // Unlimited
}
var timeRemaining time.Duration
if tokenData.ExpiresAt != nil {
timeRemaining = time.Until(*tokenData.ExpiresAt)
if timeRemaining < 0 {
timeRemaining = 0
}
} else {
timeRemaining = -1 // No expiration
}
stats := &TokenStats{
Token: tokenData.Token,
RequestCount: tokenData.RequestCount,
RemainingCount: remaining,
LastUsed: tokenData.LastUsedAt,
TimeRemaining: timeRemaining,
IsValid: tokenData.IsValid(),
ObfuscatedToken: ObfuscateToken(tokenData.Token),
}
return stats, nil
}
// UpdateToken updates an existing token in the store
func (m *Manager) UpdateToken(ctx context.Context, token TokenData) error {
// Validate token format first
if err := token.ValidateFormat(); err != nil {
return fmt.Errorf("invalid token format: %w", err)
}
return m.store.UpdateToken(ctx, token)
}
// UpdateTokenLimit updates the maximum allowed requests for a token
func (m *Manager) UpdateTokenLimit(ctx context.Context, tokenID string, maxRequests *int) error {
return m.limiter.UpdateLimit(ctx, tokenID, maxRequests)
}
// ResetTokenUsage resets the usage count for a token
func (m *Manager) ResetTokenUsage(ctx context.Context, tokenID string) error {
return m.limiter.ResetUsage(ctx, tokenID)
}
// IsTokenValid checks if a token is valid
func (m *Manager) IsTokenValid(ctx context.Context, tokenID string) bool {
_, err := m.validator.ValidateToken(ctx, tokenID)
return err == nil
}
// StartAutomaticRevocation starts automatic revocation of expired tokens
func (m *Manager) StartAutomaticRevocation(interval time.Duration, logger *zap.Logger) *AutomaticRevocation {
autoRevoke := NewAutomaticRevocation(m.revoker, interval, logger)
autoRevoke.Start()
return autoRevoke
}
// GetCacheInfo returns information about the token validation cache if caching is enabled
func (m *Manager) GetCacheInfo() (string, bool) {
if !m.useCaching {
return "Caching disabled", false
}
if cachedValidator, ok := m.validator.(*CachedValidator); ok {
return cachedValidator.GetCacheInfo(), true
}
return "Cache info not available", false
}
// WithGeneratorOptions configures the token generator with new options
func (m *Manager) WithGeneratorOptions(expiration time.Duration, maxRequests *int) *Manager {
generator := m.generator.WithExpiration(expiration)
if maxRequests != nil {
generator = generator.WithMaxRequests(*maxRequests)
}
m.generator = generator
return m
}
// TokenOptions contains options for token creation
type TokenOptions struct {
// Expiration duration (0 for no expiration)
Expiration time.Duration
// Maximum requests (nil for no limit)
MaxRequests *int
// Custom metadata (implementation-dependent)
Metadata map[string]string
}
// TokenStats contains statistics about token usage
type TokenStats struct {
Token string
ObfuscatedToken string
RequestCount int
RemainingCount int // -1 means unlimited
LastUsed *time.Time
TimeRemaining time.Duration // -1 means no expiration
IsValid bool
}
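// Hypothetical wiring sketch: create a Manager with caching enabled, mint a
// token, and validate it. The store is assumed to satisfy ManagerStore, and
// the project ID is a placeholder.
func exampleManagerUsage(ctx context.Context, store ManagerStore) error {
	m, err := NewManager(store, true)
	if err != nil {
		return err
	}
	maxReq := 100
	tok, err := m.CreateToken(ctx, "project-123", TokenOptions{
		Expiration:  24 * time.Hour,
		MaxRequests: &maxReq,
	})
	if err != nil {
		return err
	}
	_, err = m.ValidateToken(ctx, tok.Token)
	return err
}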
package token
import (
"context"
"errors"
"fmt"
"sync"
"time"
)
var (
// ErrRateLimitExceeded is returned when a token exceeds its rate limit
ErrRateLimitExceeded = errors.New("rate limit exceeded")
// ErrLimitOperation is returned when an operation on rate limits fails
ErrLimitOperation = errors.New("limit operation failed")
)
// RateLimiter defines the interface for rate limiting
type RateLimiter interface {
// AllowRequest checks if a token is within its rate limits and updates usage
AllowRequest(ctx context.Context, tokenID string) error
// GetRemainingRequests returns the number of remaining requests for a token
GetRemainingRequests(ctx context.Context, tokenID string) (int, error)
// ResetUsage resets the usage counter for a token
ResetUsage(ctx context.Context, tokenID string) error
// UpdateLimit updates the maximum allowed requests for a token
UpdateLimit(ctx context.Context, tokenID string, maxRequests *int) error
}
// RateLimitStore defines the interface for rate limit persistence
type RateLimitStore interface {
// GetTokenByID retrieves a token by its ID
GetTokenByID(ctx context.Context, tokenID string) (TokenData, error)
// IncrementTokenUsage increments the usage count for a token
IncrementTokenUsage(ctx context.Context, tokenID string) error
// ResetTokenUsage resets the usage count for a token to zero
ResetTokenUsage(ctx context.Context, tokenID string) error
// UpdateTokenLimit updates the maximum allowed requests for a token
UpdateTokenLimit(ctx context.Context, tokenID string, maxRequests *int) error
}
// StandardRateLimiter implements RateLimiter using a persistent store
type StandardRateLimiter struct {
store RateLimitStore
}
// NewRateLimiter creates a new StandardRateLimiter with the given store
func NewRateLimiter(store RateLimitStore) *StandardRateLimiter {
return &StandardRateLimiter{
store: store,
}
}
// AllowRequest checks if a token is within its rate limits and updates usage
func (r *StandardRateLimiter) AllowRequest(ctx context.Context, tokenID string) error {
// Validate token format first
if err := ValidateTokenFormat(tokenID); err != nil {
return fmt.Errorf("invalid token format: %w", err)
}
// Get current token data
token, err := r.store.GetTokenByID(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return ErrTokenNotFound
}
return fmt.Errorf("failed to retrieve token: %w", err)
}
// Check if token has a rate limit
if token.MaxRequests != nil {
// Check if token has exceeded its rate limit
if token.RequestCount >= *token.MaxRequests {
return ErrRateLimitExceeded
}
}
// Increment usage count
if err := r.store.IncrementTokenUsage(ctx, tokenID); err != nil {
return fmt.Errorf("failed to update token usage: %w", err)
}
return nil
}
// GetRemainingRequests returns the number of remaining requests for a token
func (r *StandardRateLimiter) GetRemainingRequests(ctx context.Context, tokenID string) (int, error) {
// Validate token format first
if err := ValidateTokenFormat(tokenID); err != nil {
return 0, fmt.Errorf("invalid token format: %w", err)
}
// Get current token data
token, err := r.store.GetTokenByID(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return 0, ErrTokenNotFound
}
return 0, fmt.Errorf("failed to retrieve token: %w", err)
}
// If token has no limit, return a high number
if token.MaxRequests == nil {
return 1000000000, nil // Unlimited
}
// Calculate remaining requests
remaining := *token.MaxRequests - token.RequestCount
if remaining < 0 {
remaining = 0
}
return remaining, nil
}
// ResetUsage resets the usage counter for a token
func (r *StandardRateLimiter) ResetUsage(ctx context.Context, tokenID string) error {
// Validate token format first
if err := ValidateTokenFormat(tokenID); err != nil {
return fmt.Errorf("invalid token format: %w", err)
}
// Reset token usage
if err := r.store.ResetTokenUsage(ctx, tokenID); err != nil {
if errors.Is(err, ErrTokenNotFound) {
return ErrTokenNotFound
}
return fmt.Errorf("failed to reset token usage: %w", err)
}
return nil
}
// UpdateLimit updates the maximum allowed requests for a token
func (r *StandardRateLimiter) UpdateLimit(ctx context.Context, tokenID string, maxRequests *int) error {
// Validate token format first
if err := ValidateTokenFormat(tokenID); err != nil {
return fmt.Errorf("invalid token format: %w", err)
}
// Update token limit
if err := r.store.UpdateTokenLimit(ctx, tokenID, maxRequests); err != nil {
if errors.Is(err, ErrTokenNotFound) {
return ErrTokenNotFound
}
return fmt.Errorf("failed to update token limit: %w", err)
}
return nil
}
// MemoryRateLimiter implements in-memory rate limiting with token bucket algorithm
type MemoryRateLimiter struct {
// Mapping from token ID to rate limit data
limits map[string]*TokenBucket
limitsMutex sync.RWMutex
// Default rate (tokens per second)
defaultRate float64
// Default capacity
defaultCapacity int
}
// TokenBucket represents a token bucket rate limiter for a specific token
type TokenBucket struct {
// Current number of tokens in the bucket
tokens float64
// Maximum number of tokens the bucket can hold
capacity int
// Rate at which tokens are added to the bucket (tokens per second)
rate float64
// Last time the bucket was refilled
lastRefill time.Time
// Mutex to protect the bucket from concurrent access
mutex sync.Mutex
}
// NewMemoryRateLimiter creates a new in-memory rate limiter with token bucket algorithm
func NewMemoryRateLimiter(defaultRate float64, defaultCapacity int) *MemoryRateLimiter {
return &MemoryRateLimiter{
limits: make(map[string]*TokenBucket),
defaultRate: defaultRate,
defaultCapacity: defaultCapacity,
}
}
// Allow checks if a request is allowed for a token
func (m *MemoryRateLimiter) Allow(tokenID string) bool {
m.limitsMutex.RLock()
bucket, exists := m.limits[tokenID]
m.limitsMutex.RUnlock()
if !exists {
// Create a new bucket for this token
bucket = &TokenBucket{
tokens: float64(m.defaultCapacity),
capacity: m.defaultCapacity,
rate: m.defaultRate,
lastRefill: time.Now(),
}
m.limitsMutex.Lock()
m.limits[tokenID] = bucket
m.limitsMutex.Unlock()
return true // First request is always allowed
}
// Try to consume a token from the bucket
bucket.mutex.Lock()
defer bucket.mutex.Unlock()
// Refill the bucket based on elapsed time
now := time.Now()
elapsed := now.Sub(bucket.lastRefill).Seconds()
bucket.lastRefill = now
bucket.tokens += elapsed * bucket.rate
if bucket.tokens > float64(bucket.capacity) {
bucket.tokens = float64(bucket.capacity)
}
// Check if we have a token to consume
if bucket.tokens >= 1.0 {
bucket.tokens -= 1.0
return true
}
return false
}
// SetLimit sets the rate limit for a specific token
func (m *MemoryRateLimiter) SetLimit(tokenID string, rate float64, capacity int) {
m.limitsMutex.Lock()
defer m.limitsMutex.Unlock()
bucket, exists := m.limits[tokenID]
if !exists {
bucket = &TokenBucket{
tokens: float64(capacity),
capacity: capacity,
rate: rate,
lastRefill: time.Now(),
}
m.limits[tokenID] = bucket
return
}
bucket.mutex.Lock()
defer bucket.mutex.Unlock()
// Refill first to avoid losing tokens
now := time.Now()
elapsed := now.Sub(bucket.lastRefill).Seconds()
bucket.lastRefill = now
bucket.tokens += elapsed * bucket.rate
// Update rate and capacity
bucket.rate = rate
// Adjust tokens if capacity changed
if capacity < bucket.capacity && bucket.tokens > float64(capacity) {
bucket.tokens = float64(capacity)
}
bucket.capacity = capacity
}
// GetLimit gets the current rate limit and capacity for a token
func (m *MemoryRateLimiter) GetLimit(tokenID string) (float64, int, bool) {
m.limitsMutex.RLock()
defer m.limitsMutex.RUnlock()
bucket, exists := m.limits[tokenID]
if !exists {
return m.defaultRate, m.defaultCapacity, false
}
bucket.mutex.Lock()
defer bucket.mutex.Unlock()
return bucket.rate, bucket.capacity, true
}
// Reset resets the rate limit for a token
func (m *MemoryRateLimiter) Reset(tokenID string) {
m.limitsMutex.Lock()
defer m.limitsMutex.Unlock()
bucket, exists := m.limits[tokenID]
if !exists {
return
}
bucket.mutex.Lock()
defer bucket.mutex.Unlock()
bucket.tokens = float64(bucket.capacity)
bucket.lastRefill = time.Now()
}
// Remove removes rate limit data for a token
func (m *MemoryRateLimiter) Remove(tokenID string) {
m.limitsMutex.Lock()
defer m.limitsMutex.Unlock()
delete(m.limits, tokenID)
}
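// Minimal sketch of the token bucket above; the token ID and numbers are
// illustrative. Each Allow call refills by elapsed time, then consumes one token.
func exampleMemoryRateLimiter() bool {
	rl := NewMemoryRateLimiter(5.0, 10) // 5 tokens/sec, burst capacity 10
	rl.SetLimit("tok-abc", 1.0, 2)      // tighter per-token override
	return rl.Allow("tok-abc")          // true while burst capacity remains
}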
package token
import (
"context"
"time"
"github.com/redis/go-redis/v9"
)
// RedisGoRateLimitAdapter adapts go-redis/v9 Client to the RedisRateLimitClient interface.
type RedisGoRateLimitAdapter struct {
Client *redis.Client
}
// NewRedisGoRateLimitAdapter creates a new adapter for rate limiting operations
func NewRedisGoRateLimitAdapter(client *redis.Client) *RedisGoRateLimitAdapter {
return &RedisGoRateLimitAdapter{Client: client}
}
// Incr atomically increments a key and returns the new value
func (a *RedisGoRateLimitAdapter) Incr(ctx context.Context, key string) (int64, error) {
return a.Client.Incr(ctx, key).Result()
}
// Get retrieves the value of a key
func (a *RedisGoRateLimitAdapter) Get(ctx context.Context, key string) (string, error) {
result, err := a.Client.Get(ctx, key).Result()
if err == redis.Nil {
return "", nil
}
return result, err
}
// Set sets the value of a key
func (a *RedisGoRateLimitAdapter) Set(ctx context.Context, key, value string) error {
return a.Client.Set(ctx, key, value, 0).Err()
}
// Expire sets a TTL on a key
func (a *RedisGoRateLimitAdapter) Expire(ctx context.Context, key string, expiration time.Duration) error {
return a.Client.Expire(ctx, key, expiration).Err()
}
// SetNX sets a key only if it doesn't exist (for distributed locking)
func (a *RedisGoRateLimitAdapter) SetNX(ctx context.Context, key string, value string, expiration time.Duration) (bool, error) {
return a.Client.SetNX(ctx, key, value, expiration).Result()
}
// Del deletes a key
func (a *RedisGoRateLimitAdapter) Del(ctx context.Context, key string) error {
return a.Client.Del(ctx, key).Err()
}
package token
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strconv"
"sync"
"time"
)
// ErrRedisUnavailable is returned when Redis is unavailable and fallback is disabled
var ErrRedisUnavailable = errors.New("redis unavailable for rate limiting")
// RedisRateLimitClient defines the Redis operations needed for distributed rate limiting.
// This is a subset of the eventbus.RedisClient interface focused on rate limiting operations.
type RedisRateLimitClient interface {
// Incr atomically increments a key and returns the new value
Incr(ctx context.Context, key string) (int64, error)
// Get retrieves the value of a key
Get(ctx context.Context, key string) (string, error)
// Set sets the value of a key
Set(ctx context.Context, key, value string) error
// Expire sets a TTL on a key
Expire(ctx context.Context, key string, expiration time.Duration) error
// SetNX sets a key only if it doesn't exist.
// Included for interface compatibility with eventbus.RedisClient and potential future enhancements.
SetNX(ctx context.Context, key string, value string, expiration time.Duration) (bool, error)
// Del deletes a key
Del(ctx context.Context, key string) error
}
// RedisRateLimiterConfig contains configuration for the Redis rate limiter
type RedisRateLimiterConfig struct {
// KeyPrefix is the prefix for all Redis keys used by the rate limiter
KeyPrefix string
// KeyHashSecret is the HMAC secret for hashing token IDs in Redis keys.
// When set, token IDs are hashed using HMAC-SHA256 to prevent cleartext exposure.
// This is recommended for production deployments to enhance security.
KeyHashSecret []byte
// DefaultWindowDuration is the default sliding window duration for rate limiting
DefaultWindowDuration time.Duration
// DefaultMaxRequests is the default maximum requests per window
DefaultMaxRequests int
// EnableFallback enables fallback to in-memory rate limiting when Redis is unavailable
EnableFallback bool
// FallbackRate is the rate for fallback in-memory token bucket (tokens per second)
FallbackRate float64
// FallbackCapacity is the capacity for fallback in-memory token bucket
FallbackCapacity int
}
// DefaultRedisRateLimiterConfig returns default configuration
func DefaultRedisRateLimiterConfig() RedisRateLimiterConfig {
return RedisRateLimiterConfig{
KeyPrefix: "ratelimit:",
DefaultWindowDuration: time.Minute,
DefaultMaxRequests: 60,
EnableFallback: true,
FallbackRate: 1.0, // 1 token per second
FallbackCapacity: 10,
}
}
// RedisRateLimiter implements distributed rate limiting using Redis.
// It uses a sliding window counter algorithm with Redis INCR for atomic operations.
type RedisRateLimiter struct {
client RedisRateLimitClient
config RedisRateLimiterConfig
fallback *MemoryRateLimiter
// Track Redis availability
redisAvailable bool
redisAvailableMu sync.RWMutex
// Per-token limits (optional override of defaults)
tokenLimits map[string]*TokenRateLimit
tokenLimitsMu sync.RWMutex
}
// TokenRateLimit holds rate limit configuration for a specific token
type TokenRateLimit struct {
MaxRequests int
WindowDuration time.Duration
}
// NewRedisRateLimiter creates a new distributed rate limiter using Redis
func NewRedisRateLimiter(client RedisRateLimitClient, config RedisRateLimiterConfig) *RedisRateLimiter {
limiter := &RedisRateLimiter{
client: client,
config: config,
redisAvailable: true,
tokenLimits: make(map[string]*TokenRateLimit),
}
// Create fallback in-memory limiter if enabled
if config.EnableFallback {
limiter.fallback = NewMemoryRateLimiter(config.FallbackRate, config.FallbackCapacity)
}
return limiter
}
// buildKey constructs the Redis key for a token's rate limit counter.
// If KeyHashSecret is configured, the token ID is hashed using HMAC-SHA256
// to prevent cleartext exposure in Redis keys.
func (r *RedisRateLimiter) buildKey(tokenID string, windowStart int64) string {
keyID := tokenID
if len(r.config.KeyHashSecret) > 0 {
keyID = hashTokenID(tokenID, r.config.KeyHashSecret)
}
return fmt.Sprintf("%s%s:%d", r.config.KeyPrefix, keyID, windowStart)
}
// hashTokenID generates a non-reversible identifier from a token ID using HMAC-SHA256.
// Returns the first 16 hex characters of the HMAC for brevity while maintaining uniqueness.
func hashTokenID(tokenID string, secret []byte) string {
h := hmac.New(sha256.New, secret)
h.Write([]byte(tokenID))
return hex.EncodeToString(h.Sum(nil))[:16]
}
// getWindowStart returns the start timestamp for the current window
func (r *RedisRateLimiter) getWindowStart(windowDuration time.Duration) int64 {
now := time.Now()
return now.Truncate(windowDuration).Unix()
}
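// For example, with KeyPrefix "ratelimit:", a hashed token ID, and a
// one-minute window starting at Unix time 1700000000, buildKey yields a
// key like (the hex digest shown is illustrative):
//
//	ratelimit:3f2a9c1b8d4e6f70:1700000000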
// getTokenLimit returns the rate limit for a token, using defaults if not set
func (r *RedisRateLimiter) getTokenLimit(tokenID string) (int, time.Duration) {
r.tokenLimitsMu.RLock()
limit, exists := r.tokenLimits[tokenID]
r.tokenLimitsMu.RUnlock()
if exists && limit != nil {
return limit.MaxRequests, limit.WindowDuration
}
return r.config.DefaultMaxRequests, r.config.DefaultWindowDuration
}
// Allow checks if a request from the given token should be allowed.
// Returns true if the request is within rate limits, false otherwise.
func (r *RedisRateLimiter) Allow(ctx context.Context, tokenID string) (bool, error) {
maxRequests, windowDuration := r.getTokenLimit(tokenID)
windowStart := r.getWindowStart(windowDuration)
key := r.buildKey(tokenID, windowStart)
// Check Redis availability
r.redisAvailableMu.RLock()
available := r.redisAvailable
r.redisAvailableMu.RUnlock()
if !available {
return r.handleFallback(tokenID)
}
// Try to increment counter atomically in Redis
count, err := r.client.Incr(ctx, key)
if err != nil {
// Redis operation failed
r.markRedisUnavailable()
return r.handleFallback(tokenID)
}
// Mark Redis as available (successful operation)
r.markRedisAvailable()
// Set expiration on first increment (count == 1)
if count == 1 {
// Set TTL slightly longer than window to handle edge cases
ttl := windowDuration + time.Second
// Ignore expire errors for graceful degradation; orphaned keys will be cleaned up by Redis eventually.
_ = r.client.Expire(ctx, key, ttl)
}
// Check if request is within limit
return count <= int64(maxRequests), nil
}
// handleFallback handles rate limiting when Redis is unavailable
func (r *RedisRateLimiter) handleFallback(tokenID string) (bool, error) {
if !r.config.EnableFallback || r.fallback == nil {
return false, ErrRedisUnavailable
}
return r.fallback.Allow(tokenID), nil
}
// markRedisUnavailable marks Redis as unavailable
func (r *RedisRateLimiter) markRedisUnavailable() {
r.redisAvailableMu.Lock()
r.redisAvailable = false
r.redisAvailableMu.Unlock()
}
// markRedisAvailable marks Redis as available
func (r *RedisRateLimiter) markRedisAvailable() {
r.redisAvailableMu.Lock()
r.redisAvailable = true
r.redisAvailableMu.Unlock()
}
// GetRemainingRequests returns the number of remaining requests for a token in the current window
func (r *RedisRateLimiter) GetRemainingRequests(ctx context.Context, tokenID string) (int, error) {
maxRequests, windowDuration := r.getTokenLimit(tokenID)
windowStart := r.getWindowStart(windowDuration)
key := r.buildKey(tokenID, windowStart)
// Check Redis availability
r.redisAvailableMu.RLock()
available := r.redisAvailable
r.redisAvailableMu.RUnlock()
if !available {
if r.config.EnableFallback {
// Return a reasonable default when in fallback mode
return r.config.FallbackCapacity, nil
}
return 0, ErrRedisUnavailable
}
// Get current count from Redis
countStr, err := r.client.Get(ctx, key)
if err != nil {
// Redis error - mark unavailable and use fallback
r.markRedisUnavailable()
if r.config.EnableFallback {
return r.config.FallbackCapacity, nil
}
return 0, fmt.Errorf("failed to get rate limit counter: %w", err)
}
// Key doesn't exist (empty string returned, no error) - all requests remain
if countStr == "" {
return maxRequests, nil
}
// Parse the counter; treat a malformed value as zero (full allowance remains)
count, parseErr := strconv.Atoi(countStr)
if parseErr != nil {
	count = 0
}
remaining := maxRequests - count
if remaining < 0 {
remaining = 0
}
return remaining, nil
}
// SetTokenLimit sets a custom rate limit for a specific token
func (r *RedisRateLimiter) SetTokenLimit(tokenID string, maxRequests int, windowDuration time.Duration) {
r.tokenLimitsMu.Lock()
r.tokenLimits[tokenID] = &TokenRateLimit{
MaxRequests: maxRequests,
WindowDuration: windowDuration,
}
r.tokenLimitsMu.Unlock()
}
// RemoveTokenLimit removes the custom rate limit for a token (falls back to defaults)
func (r *RedisRateLimiter) RemoveTokenLimit(tokenID string) {
r.tokenLimitsMu.Lock()
delete(r.tokenLimits, tokenID)
r.tokenLimitsMu.Unlock()
}
// ResetTokenUsage resets the rate limit counter for a token
func (r *RedisRateLimiter) ResetTokenUsage(ctx context.Context, tokenID string) error {
_, windowDuration := r.getTokenLimit(tokenID)
windowStart := r.getWindowStart(windowDuration)
key := r.buildKey(tokenID, windowStart)
// Check Redis availability
r.redisAvailableMu.RLock()
available := r.redisAvailable
r.redisAvailableMu.RUnlock()
if !available {
if r.config.EnableFallback && r.fallback != nil {
r.fallback.Reset(tokenID)
return nil
}
return ErrRedisUnavailable
}
if err := r.client.Del(ctx, key); err != nil {
r.markRedisUnavailable()
if r.config.EnableFallback && r.fallback != nil {
r.fallback.Reset(tokenID)
return nil
}
return fmt.Errorf("failed to reset rate limit counter: %w", err)
}
r.markRedisAvailable()
return nil
}
// IsRedisAvailable returns whether Redis is currently available
func (r *RedisRateLimiter) IsRedisAvailable() bool {
r.redisAvailableMu.RLock()
defer r.redisAvailableMu.RUnlock()
return r.redisAvailable
}
// CheckRedisHealth performs a health check on the Redis connection
func (r *RedisRateLimiter) CheckRedisHealth(ctx context.Context) error {
// Try a simple operation to verify Redis is working
testKey := r.config.KeyPrefix + "_healthcheck"
if err := r.client.Set(ctx, testKey, "1"); err != nil {
r.markRedisUnavailable()
return fmt.Errorf("redis health check failed: %w", err)
}
// Cleanup
_ = r.client.Del(ctx, testKey)
r.markRedisAvailable()
return nil
}
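// Hypothetical wiring sketch: the client is assumed to satisfy
// RedisRateLimitClient (for example, the go-redis adapter above); the token
// ID, secret, and limit values are illustrative.
func exampleRedisRateLimiter(ctx context.Context, client RedisRateLimitClient) (bool, error) {
	cfg := DefaultRedisRateLimiterConfig()
	cfg.KeyHashSecret = []byte("rotate-me") // hash token IDs in Redis keys
	limiter := NewRedisRateLimiter(client, cfg)
	limiter.SetTokenLimit("tok-abc", 120, time.Minute) // per-token override
	return limiter.Allow(ctx, "tok-abc")
}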
package token
import (
"context"
"errors"
"fmt"
"time"
"go.uber.org/zap"
)
var (
// ErrTokenAlreadyRevoked is returned when trying to revoke an already revoked token
ErrTokenAlreadyRevoked = errors.New("token is already revoked")
)
// RevocationStore defines the interface for token revocation
type RevocationStore interface {
// RevokeToken disables a token by setting is_active to false
RevokeToken(ctx context.Context, tokenID string) error
// DeleteToken completely removes a token from storage
DeleteToken(ctx context.Context, tokenID string) error
// RevokeBatchTokens revokes multiple tokens at once
RevokeBatchTokens(ctx context.Context, tokenIDs []string) (int, error)
// RevokeProjectTokens revokes all tokens for a project
RevokeProjectTokens(ctx context.Context, projectID string) (int, error)
// RevokeExpiredTokens revokes all tokens that have expired
RevokeExpiredTokens(ctx context.Context) (int, error)
}
// Revoker provides methods for token revocation
type Revoker struct {
store RevocationStore
}
// NewRevoker creates a new token revoker with the given store
func NewRevoker(store RevocationStore) *Revoker {
return &Revoker{
store: store,
}
}
// RevokeToken soft revokes a token by setting is_active to false
func (r *Revoker) RevokeToken(ctx context.Context, tokenID string) error {
// Validate token format first
if err := ValidateTokenFormat(tokenID); err != nil {
return fmt.Errorf("invalid token format: %w", err)
}
// Attempt to revoke the token
err := r.store.RevokeToken(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return fmt.Errorf("cannot revoke: %w", err)
}
if errors.Is(err, ErrTokenAlreadyRevoked) {
return err
}
return fmt.Errorf("failed to revoke token: %w", err)
}
return nil
}
// DeleteToken completely removes a token from storage (hard delete)
func (r *Revoker) DeleteToken(ctx context.Context, tokenID string) error {
// Validate token format first
if err := ValidateTokenFormat(tokenID); err != nil {
return fmt.Errorf("invalid token format: %w", err)
}
// Attempt to delete the token
err := r.store.DeleteToken(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return fmt.Errorf("cannot delete: %w", err)
}
return fmt.Errorf("failed to delete token: %w", err)
}
return nil
}
// RevokeBatchTokens revokes multiple tokens in a single operation
func (r *Revoker) RevokeBatchTokens(ctx context.Context, tokenIDs []string) (int, error) {
if len(tokenIDs) == 0 {
return 0, nil
}
// Validate all token formats first
for _, tokenID := range tokenIDs {
if err := ValidateTokenFormat(tokenID); err != nil {
return 0, fmt.Errorf("invalid token format for %s: %w", tokenID, err)
}
}
// Revoke the tokens
count, err := r.store.RevokeBatchTokens(ctx, tokenIDs)
if err != nil {
return 0, fmt.Errorf("failed to revoke tokens in batch: %w", err)
}
return count, nil
}
// RevokeProjectTokens revokes all tokens for a project
func (r *Revoker) RevokeProjectTokens(ctx context.Context, projectID string) (int, error) {
if projectID == "" {
return 0, errors.New("project ID cannot be empty")
}
count, err := r.store.RevokeProjectTokens(ctx, projectID)
if err != nil {
return 0, fmt.Errorf("failed to revoke project tokens: %w", err)
}
return count, nil
}
// RevokeExpiredTokens revokes all tokens that have expired
func (r *Revoker) RevokeExpiredTokens(ctx context.Context) (int, error) {
count, err := r.store.RevokeExpiredTokens(ctx)
if err != nil {
return 0, fmt.Errorf("failed to revoke expired tokens: %w", err)
}
return count, nil
}
// AutomaticRevocation sets up periodic revocation of expired tokens
type AutomaticRevocation struct {
revoker *Revoker
interval time.Duration
stopChan chan struct{}
stoppedChan chan struct{}
logger *zap.Logger
}
// NewAutomaticRevocation creates a new automatic token revocation
func NewAutomaticRevocation(revoker *Revoker, interval time.Duration, logger *zap.Logger) *AutomaticRevocation {
return &AutomaticRevocation{
revoker: revoker,
interval: interval,
stopChan: make(chan struct{}),
stoppedChan: make(chan struct{}),
logger: logger,
}
}
// Start begins the automatic revocation of expired tokens
func (a *AutomaticRevocation) Start() {
go func() {
ticker := time.NewTicker(a.interval)
defer ticker.Stop()
defer close(a.stoppedChan)
for {
select {
case <-ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
count, err := a.revoker.RevokeExpiredTokens(ctx)
if err != nil {
a.logger.Error("Failed to automatically revoke expired tokens", zap.Error(err))
} else if count > 0 {
a.logger.Info("Automatically revoked expired tokens", zap.Int("count", count))
}
cancel()
case <-a.stopChan:
return
}
}
}()
}
// Stop halts the automatic revocation
func (a *AutomaticRevocation) Stop() {
close(a.stopChan)
<-a.stoppedChan
}
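// Illustrative sketch: run periodic revocation and stop it on shutdown.
// The store is assumed to satisfy RevocationStore; the interval is arbitrary.
func exampleAutomaticRevocation(store RevocationStore, logger *zap.Logger) {
	auto := NewAutomaticRevocation(NewRevoker(store), time.Hour, logger)
	auto.Start()
	defer auto.Stop() // blocks until the background goroutine exits
}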
package token
import (
"encoding/base64"
"errors"
"fmt"
"regexp"
"strings"
"github.com/google/uuid"
)
const (
// TokenPrefix is the prefix for all tokens
TokenPrefix = "sk-"
// TokenRegexPattern is the regular expression for validating token format
TokenRegexPattern = `^sk-[A-Za-z0-9_-]{22}$`
)
var (
// TokenRegex is the compiled regular expression for token format validation
TokenRegex = regexp.MustCompile(TokenRegexPattern)
// ErrInvalidTokenFormat is returned when the token format is invalid
ErrInvalidTokenFormat = errors.New("invalid token format")
// ErrTokenDecodingFailed is returned when the token cannot be decoded
ErrTokenDecodingFailed = errors.New("token decoding failed")
)
// GenerateToken generates a new token consisting of TokenPrefix followed by a
// base64url-encoded UUIDv7. The UUIDv7 embeds the current timestamp, making
// tokens time-ordered.
func GenerateToken() (string, error) {
// Generate a UUIDv7 which includes a timestamp component
id, err := uuid.NewV7()
if err != nil {
return "", fmt.Errorf("failed to generate UUID: %w", err)
}
// Convert UUID to Base64 URL-safe encoding
uuidBytes, err := id.MarshalBinary()
if err != nil {
return "", fmt.Errorf("failed to marshal UUID: %w", err)
}
// Use URL-safe base64 encoding without padding
encoded := base64.RawURLEncoding.EncodeToString(uuidBytes)
// Combine prefix with encoded UUID
token := TokenPrefix + encoded
return token, nil
}
// ValidateTokenFormat checks if the given token string follows the expected format.
// It does not check if the token exists or is valid in the database.
func ValidateTokenFormat(token string) error {
// Check format with regex
if !TokenRegex.MatchString(token) {
return ErrInvalidTokenFormat
}
// Attempt to decode the token to ensure it was properly generated
_, err := DecodeToken(token)
if err != nil {
return fmt.Errorf("%w: %v", ErrTokenDecodingFailed, err)
}
return nil
}
// DecodeToken extracts the UUID from a token string.
func DecodeToken(token string) (uuid.UUID, error) {
// Check if the token has the correct prefix
if !strings.HasPrefix(token, TokenPrefix) {
return uuid.UUID{}, ErrInvalidTokenFormat
}
// Remove the prefix
encodedPart := strings.TrimPrefix(token, TokenPrefix)
// Decode from base64
uuidBytes, err := base64.RawURLEncoding.DecodeString(encodedPart)
if err != nil {
return uuid.UUID{}, fmt.Errorf("failed to decode token: %w", err)
}
// Parse the UUID
var id uuid.UUID
if err := id.UnmarshalBinary(uuidBytes); err != nil {
return uuid.UUID{}, fmt.Errorf("failed to unmarshal UUID: %w", err)
}
return id, nil
}
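// Round-trip sketch: generate a token and recover its embedded UUIDv7.
func exampleTokenRoundTrip() error {
	tok, err := GenerateToken() // "sk-" + 22 base64url characters
	if err != nil {
		return err
	}
	if err := ValidateTokenFormat(tok); err != nil {
		return err
	}
	id, err := DecodeToken(tok)
	if err != nil {
		return err
	}
	_ = id // UUIDv7; its timestamp bits make tokens sortable by creation time
	return nil
}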
package token
import (
"crypto/rand"
"fmt"
"math/big"
"net/http"
"strings"
"time"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
)
// Constants for token options and generation
const (
// MinTokenLength is the minimum acceptable token length
MinTokenLength = 20
// DefaultTokenExpiration is the default expiration for tokens (30 days)
DefaultTokenExpiration = 30 * 24 * time.Hour
// DefaultMaxRequests is the default maximum requests per token (unlimited)
DefaultMaxRequests = 0 // 0 means unlimited
)
// TokenGenerator is a utility for generating tokens with specific options
type TokenGenerator struct {
// Default expiration time for new tokens
DefaultExpiration time.Duration
// Default maximum requests for new tokens (nil means unlimited)
DefaultMaxRequests *int
}
// NewTokenGenerator creates a new TokenGenerator with default options
func NewTokenGenerator() *TokenGenerator {
return &TokenGenerator{
DefaultExpiration: DefaultTokenExpiration,
DefaultMaxRequests: nil, // Unlimited by default
}
}
// WithExpiration sets the default expiration for new tokens
func (g *TokenGenerator) WithExpiration(expiration time.Duration) *TokenGenerator {
g.DefaultExpiration = expiration
return g
}
// WithMaxRequests sets the default maximum requests for new tokens
func (g *TokenGenerator) WithMaxRequests(maxRequests int) *TokenGenerator {
g.DefaultMaxRequests = &maxRequests
return g
}
// Generate generates a new token with default options
func (g *TokenGenerator) Generate() (string, error) {
return GenerateToken()
}
// GenerateWithOptions generates a new token with specific options. A zero
// expiration falls back to the generator's DefaultExpiration, and a nil
// maxRequests falls back to its DefaultMaxRequests.
func (g *TokenGenerator) GenerateWithOptions(expiration time.Duration, maxRequests *int) (string, *time.Time, *int, error) {
// Generate token
token, err := GenerateToken()
if err != nil {
return "", nil, nil, err
}
// Calculate expiration
var expiresAt *time.Time
if expiration > 0 {
exp := CalculateExpiration(expiration)
expiresAt = exp
} else if g.DefaultExpiration > 0 {
exp := CalculateExpiration(g.DefaultExpiration)
expiresAt = exp
}
// Determine max requests
var maxReq *int
if maxRequests != nil {
maxReq = maxRequests
} else {
maxReq = g.DefaultMaxRequests
}
return token, expiresAt, maxReq, nil
}
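// Illustrative sketch (not part of the original source; the function name is
// hypothetical): issuing a token that expires in 24 hours and allows at most
// 100 requests. Passing a zero expiration and nil maxRequests falls back to
// the defaults configured via the With* builders.
func exampleIssueToken() (string, *time.Time, *int, error) {
	gen := NewTokenGenerator().
		WithExpiration(24 * time.Hour).
		WithMaxRequests(100)
	return gen.GenerateWithOptions(0, nil)
}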
// ExtractTokenFromHeader extracts a token from an HTTP Authorization header
func ExtractTokenFromHeader(header string) (string, bool) {
if header == "" {
return "", false
}
// Check for "Bearer" auth scheme
parts := strings.Split(header, " ")
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
return "", false
}
token := parts[1]
if token == "" {
return "", false
}
// Validate the token format
if err := ValidateTokenFormat(token); err != nil {
return "", false
}
return token, true
}
// ExtractTokenFromRequest extracts a token from an HTTP request
func ExtractTokenFromRequest(r *http.Request) (string, bool) {
// Try Authorization header first
token, ok := ExtractTokenFromHeader(r.Header.Get("Authorization"))
if ok {
return token, true
}
// Try X-API-Key header next
apiKey := r.Header.Get("X-API-Key")
if apiKey != "" {
if err := ValidateTokenFormat(apiKey); err == nil {
return apiKey, true
}
}
// Try query parameter last
queryToken := r.URL.Query().Get("token")
if queryToken != "" {
if err := ValidateTokenFormat(queryToken); err == nil {
return queryToken, true
}
}
return "", false
}
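// Illustrative sketch (not part of the original source; the handler name and
// response text are hypothetical): guarding an http.HandlerFunc with
// ExtractTokenFromRequest. A real deployment would also validate the token
// against a TokenStore before calling the next handler.
func requireToken(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		token, ok := ExtractTokenFromRequest(r)
		if !ok {
			http.Error(w, "missing or malformed token", http.StatusUnauthorized)
			return
		}
		_ = token // format is valid; store-backed validation happens elsewhere
		next(w, r)
	}
}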
// GenerateRandomKey generates a random alphanumeric string suitable for use as
// an API key. Lengths below MinTokenLength are silently raised to MinTokenLength.
func GenerateRandomKey(length int) (string, error) {
if length < MinTokenLength {
length = MinTokenLength
}
// Character set for random key
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
charsetLen := big.NewInt(int64(len(charset)))
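// rand.Int returns a uniformly distributed value in [0, charsetLen), so
// selecting characters this way carries no modulo bias.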
// Build the random string
b := make([]byte, length)
for i := 0; i < length; i++ {
n, err := rand.Int(rand.Reader, charsetLen)
if err != nil {
return "", err
}
b[i] = charset[n.Int64()]
}
return string(b), nil
}
// TruncateToken truncates a token string for display, preserving the prefix and suffix
func TruncateToken(token string, showChars int) string {
if token == "" || len(token) <= showChars*2 {
return token
}
prefix := token[:showChars]
suffix := token[len(token)-showChars:]
return prefix + "..." + suffix
}
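// Worked example (illustrative): TruncateToken("sk-AbCdEfGhIjKlMnOpQrStUv", 4)
// returns "sk-A...StUv"; inputs of length <= 2*showChars are returned unchanged.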
// ObfuscateToken partially obfuscates a token for display purposes
func ObfuscateToken(token string) string { return obfuscate.ObfuscateTokenByPrefix(token, TokenPrefix) }
// TokenInfo represents information about a token for display purposes
type TokenInfo struct {
Token string `json:"token"`
ObfuscatedToken string `json:"obfuscated_token"`
CreationTime time.Time `json:"creation_time"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
IsActive bool `json:"is_active"`
RequestCount int `json:"request_count"`
MaxRequests *int `json:"max_requests,omitempty"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
TimeRemaining string `json:"time_remaining,omitempty"`
IsValid bool `json:"is_valid"`
}
// GetTokenInfo creates a TokenInfo struct with token details
func GetTokenInfo(token TokenData) TokenInfo {
// Use the canonical creation time from the database
// Note: UUIDv7 embeds a timestamp, but the library does not expose it; CreatedAt is the source of truth
creationTime := token.CreatedAt
info := TokenInfo{
Token: token.Token,
ObfuscatedToken: ObfuscateToken(token.Token),
CreationTime: creationTime,
ExpiresAt: token.ExpiresAt,
IsActive: token.IsActive,
RequestCount: token.RequestCount,
MaxRequests: token.MaxRequests,
LastUsedAt: token.LastUsedAt,
IsValid: token.IsActive && !IsExpired(token.ExpiresAt) && !token.IsRateLimited(),
}
// Calculate time remaining
if token.ExpiresAt != nil && !IsExpired(token.ExpiresAt) {
duration := time.Until(*token.ExpiresAt)
info.TimeRemaining = formatDuration(duration)
}
return info
}
// FormatTokenInfo formats token information as a human-readable string
func FormatTokenInfo(token TokenData) string {
info := GetTokenInfo(token)
var sb strings.Builder
sb.WriteString("Token: " + info.ObfuscatedToken + "\n")
sb.WriteString("Created: " + info.CreationTime.Format(time.RFC3339) + "\n")
if info.ExpiresAt != nil {
sb.WriteString("Expires: " + info.ExpiresAt.Format(time.RFC3339))
if info.TimeRemaining != "" {
sb.WriteString(" (" + info.TimeRemaining + " remaining)")
}
sb.WriteString("\n")
} else {
sb.WriteString("Expires: Never\n")
}
sb.WriteString("Active: " + fmt.Sprintf("%t", info.IsActive) + "\n")
if info.MaxRequests != nil {
sb.WriteString(fmt.Sprintf("Requests: %d / %d\n", info.RequestCount, *info.MaxRequests))
} else {
sb.WriteString(fmt.Sprintf("Requests: %d / ∞\n", info.RequestCount))
}
if info.LastUsedAt != nil {
sb.WriteString("Last Used: " + info.LastUsedAt.Format(time.RFC3339) + "\n")
} else {
sb.WriteString("Last Used: Never\n")
}
sb.WriteString("Valid: " + fmt.Sprintf("%t", info.IsValid) + "\n")
return sb.String()
}
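// For reference, FormatTokenInfo produces output along these lines (all values
// illustrative; the obfuscated form depends on the obfuscate helper):
//
//	Token: sk-AZKQ...u8g
//	Created: 2024-01-15T10:30:00Z
//	Expires: 2024-02-14T10:30:00Z (29 days remaining)
//	Active: true
//	Requests: 42 / 100
//	Last Used: 2024-01-20T08:15:00Z
//	Valid: true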
// formatDuration formats a duration in a human-readable way; months and years
// are approximated as 30-day and 365-day spans.
func formatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%d seconds", int(d.Seconds()))
} else if d < time.Hour {
return fmt.Sprintf("%d minutes", int(d.Minutes()))
} else if d < 24*time.Hour {
return fmt.Sprintf("%d hours", int(d.Hours()))
} else if d < 30*24*time.Hour {
return fmt.Sprintf("%d days", int(d.Hours()/24))
} else if d < 365*24*time.Hour {
return fmt.Sprintf("%d months", int(d.Hours()/(24*30)))
}
return fmt.Sprintf("%d years", int(d.Hours()/(24*365)))
}
package token
import (
"context"
"errors"
"fmt"
"time"
)
// Errors related to token validation
var (
ErrTokenNotFound = errors.New("token not found")
ErrTokenInactive = errors.New("token is inactive")
ErrTokenExpired = errors.New("token has expired")
ErrTokenRateLimit = errors.New("token has reached rate limit")
)
// TokenValidator defines the interface for token validation
type TokenValidator interface {
// ValidateToken validates a token and returns the associated project ID
ValidateToken(ctx context.Context, token string) (string, error)
// ValidateTokenWithTracking validates a token, returns the project ID, and tracks usage
ValidateTokenWithTracking(ctx context.Context, token string) (string, error)
}
// TokenStore defines the interface for token storage and retrieval
type TokenStore interface {
// GetTokenByID retrieves a token by its ID
GetTokenByID(ctx context.Context, tokenID string) (TokenData, error)
// IncrementTokenUsage increments the usage count for a token
IncrementTokenUsage(ctx context.Context, tokenID string) error
// CreateToken creates a new token in the store
CreateToken(ctx context.Context, token TokenData) error
// UpdateToken updates an existing token
UpdateToken(ctx context.Context, token TokenData) error
// ListTokens retrieves all tokens from the store
ListTokens(ctx context.Context) ([]TokenData, error)
// GetTokensByProjectID retrieves all tokens for a project
GetTokensByProjectID(ctx context.Context, projectID string) ([]TokenData, error)
}
// TokenData represents the data associated with a token
type TokenData struct {
Token string // The token ID
ProjectID string // The associated project ID
ExpiresAt *time.Time // When the token expires (nil for no expiration)
IsActive bool // Whether the token is active
DeactivatedAt *time.Time // When the token was deactivated (nil if not deactivated)
RequestCount int // Number of requests made with this token
MaxRequests *int // Maximum number of requests allowed (nil for unlimited)
CreatedAt time.Time // When the token was created
LastUsedAt *time.Time // When the token was last used (nil if never used)
CacheHitCount int // Number of cache hits for this token
}
// IsValid returns true if the token is active, not expired, and not rate limited
func (t *TokenData) IsValid() bool {
return t.IsActive && !IsExpired(t.ExpiresAt) && !t.IsRateLimited()
}
// IsRateLimited returns true if the token has reached its maximum number of requests
func (t *TokenData) IsRateLimited() bool {
if t.MaxRequests == nil {
return false
}
return t.RequestCount >= *t.MaxRequests
}
// ValidateFormat checks if the token has the correct format
func (t *TokenData) ValidateFormat() error {
return ValidateTokenFormat(t.Token)
}
// StandardValidator is a validator that uses a TokenStore for validation
type StandardValidator struct {
store TokenStore
}
// NewValidator creates a new StandardValidator with the given TokenStore
func NewValidator(store TokenStore) *StandardValidator {
return &StandardValidator{
store: store,
}
}
// ValidateToken validates a token without incrementing usage
func (v *StandardValidator) ValidateToken(ctx context.Context, tokenID string) (string, error) {
// First validate the token format
if err := ValidateTokenFormat(tokenID); err != nil {
return "", fmt.Errorf("invalid token format: %w", err)
}
// Retrieve the token from the store
tokenData, err := v.store.GetTokenByID(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return "", ErrTokenNotFound
}
return "", fmt.Errorf("failed to retrieve token: %w", err)
}
// Check if the token is active
if !tokenData.IsActive {
return "", ErrTokenInactive
}
// Check if the token has expired
if IsExpired(tokenData.ExpiresAt) {
return "", ErrTokenExpired
}
// Check if the token has reached its rate limit
if tokenData.IsRateLimited() {
return "", ErrTokenRateLimit
}
// Token is valid, return the project ID
return tokenData.ProjectID, nil
}
// ValidateTokenWithTracking validates a token and increments its usage count
func (v *StandardValidator) ValidateTokenWithTracking(ctx context.Context, tokenID string) (string, error) {
// Validate the token first
projectID, err := v.ValidateToken(ctx, tokenID)
if err != nil {
return "", err
}
// Increment the token usage
if err := v.store.IncrementTokenUsage(ctx, tokenID); err != nil {
return "", fmt.Errorf("failed to track token usage: %w", err)
}
return projectID, nil
}
// ValidateToken validates a token using the provided validator
func ValidateToken(ctx context.Context, validator TokenValidator, tokenID string) (string, error) {
return validator.ValidateToken(ctx, tokenID)
}
// ValidateTokenWithTracking validates a token and tracks its usage
func ValidateTokenWithTracking(ctx context.Context, validator TokenValidator, tokenID string) (string, error) {
return validator.ValidateTokenWithTracking(ctx, tokenID)
}
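// Illustrative sketch (not part of the original source; all names here are
// hypothetical): a minimal in-memory TokenStore, handy for tests and for
// seeing how StandardValidator drives the interface. The map is not
// synchronized; concurrent use would need a mutex.
type memoryTokenStore struct {
	tokens map[string]TokenData
}

func (m *memoryTokenStore) GetTokenByID(ctx context.Context, tokenID string) (TokenData, error) {
	td, ok := m.tokens[tokenID]
	if !ok {
		return TokenData{}, ErrTokenNotFound
	}
	return td, nil
}

func (m *memoryTokenStore) IncrementTokenUsage(ctx context.Context, tokenID string) error {
	td, ok := m.tokens[tokenID]
	if !ok {
		return ErrTokenNotFound
	}
	td.RequestCount++
	now := time.Now()
	td.LastUsedAt = &now
	m.tokens[tokenID] = td
	return nil
}

func (m *memoryTokenStore) CreateToken(ctx context.Context, token TokenData) error {
	m.tokens[token.Token] = token
	return nil
}

func (m *memoryTokenStore) UpdateToken(ctx context.Context, token TokenData) error {
	if _, ok := m.tokens[token.Token]; !ok {
		return ErrTokenNotFound
	}
	m.tokens[token.Token] = token
	return nil
}

func (m *memoryTokenStore) ListTokens(ctx context.Context) ([]TokenData, error) {
	out := make([]TokenData, 0, len(m.tokens))
	for _, td := range m.tokens {
		out = append(out, td)
	}
	return out, nil
}

func (m *memoryTokenStore) GetTokensByProjectID(ctx context.Context, projectID string) ([]TokenData, error) {
	var out []TokenData
	for _, td := range m.tokens {
		if td.ProjectID == projectID {
			out = append(out, td)
		}
	}
	return out, nil
}

// Wiring it up (illustrative):
//
//	store := &memoryTokenStore{tokens: map[string]TokenData{}}
//	_ = store.CreateToken(ctx, TokenData{Token: tok, ProjectID: "p1", IsActive: true})
//	validator := NewValidator(store)
//	projectID, err := validator.ValidateTokenWithTracking(ctx, tok)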
// Package utils provides common utility functions.
package utils
import (
"crypto/rand"
"encoding/hex"
"fmt"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
)
// GenerateSecureToken generates a secure random token from length random
// bytes, hex-encoded into a string of 2*length characters
func GenerateSecureToken(length int) (string, error) {
if length <= 0 {
return "", fmt.Errorf("length must be positive")
}
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", fmt.Errorf("failed to generate secure token: %w", err)
}
return hex.EncodeToString(b), nil
}
// GenerateSecureTokenMustSucceed generates a secure random token or panics
// This is useful for initialization code where failure is unrecoverable
func GenerateSecureTokenMustSucceed(length int) string {
token, err := GenerateSecureToken(length)
if err != nil {
panic(err)
}
return token
}
// Deprecated: use obfuscate.ObfuscateTokenGeneric directly at call sites.
func ObfuscateToken(token string) string { return obfuscate.ObfuscateTokenGeneric(token) }