diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 76273cf48..0fe09d5f8 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -100,7 +100,7 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv }, nil } -func NewStdioMCPServer(cfg github.MCPServerConfig) (*mcp.Server, error) { +func NewStdioMCPServer(ctx context.Context, cfg github.MCPServerConfig) (*mcp.Server, error) { apiHost, err := utils.NewAPIHost(cfg.Host) if err != nil { return nil, fmt.Errorf("failed to parse API host: %w", err) @@ -144,7 +144,7 @@ func NewStdioMCPServer(cfg github.MCPServerConfig) (*mcp.Server, error) { return nil, fmt.Errorf("failed to build inventory: %w", err) } - ghServer, err := github.NewMCPServer(&cfg, deps, inventory) + ghServer, err := github.NewMCPServer(ctx, &cfg, deps, inventory) if err != nil { return nil, fmt.Errorf("failed to create GitHub MCP server: %w", err) } @@ -246,7 +246,7 @@ func RunStdioServer(cfg StdioServerConfig) error { logger.Debug("skipping scope filtering for non-PAT token") } - ghServer, err := NewStdioMCPServer(github.MCPServerConfig{ + ghServer, err := NewStdioMCPServer(ctx, github.MCPServerConfig{ Version: cfg.Version, Host: cfg.Host, Token: cfg.Token, diff --git a/pkg/context/request.go b/pkg/context/request.go new file mode 100644 index 000000000..8b9169955 --- /dev/null +++ b/pkg/context/request.go @@ -0,0 +1,51 @@ +package context + +import "context" + +// readonlyCtxKey is a context key for read-only mode +type readonlyCtxKey struct{} + +// WithReadonly adds read-only mode state to the context +func WithReadonly(ctx context.Context, enabled bool) context.Context { + return context.WithValue(ctx, readonlyCtxKey{}, enabled) +} + +// IsReadonly retrieves the read-only mode state from the context +func IsReadonly(ctx context.Context) bool { + if enabled, ok := ctx.Value(readonlyCtxKey{}).(bool); ok { + return enabled + } + return false +} + +// toolsetsCtxKey is a context key for the active toolsets +type toolsetsCtxKey struct{} + +// WithToolsets adds the active toolsets to the context +func WithToolsets(ctx context.Context, toolsets []string) context.Context { + return context.WithValue(ctx, toolsetsCtxKey{}, toolsets) +} + +// GetToolsets retrieves the active toolsets from the context +func GetToolsets(ctx context.Context) []string { + if toolsets, ok := ctx.Value(toolsetsCtxKey{}).([]string); ok { + return toolsets + } + return nil +} + +// toolsCtxKey is a context key for tools +type toolsCtxKey struct{} + +// WithTools adds the tools to the context +func WithTools(ctx context.Context, tools []string) context.Context { + return context.WithValue(ctx, toolsCtxKey{}, tools) +} + +// GetTools retrieves the tools from the context +func GetTools(ctx context.Context) []string { + if tools, ok := ctx.Value(toolsCtxKey{}).([]string); ok { + return tools + } + return nil +} diff --git a/pkg/github/server.go b/pkg/github/server.go index 46e1f4f34..bded66776 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -73,7 +73,7 @@ type MCPServerConfig struct { type MCPServerOption func(*mcp.ServerOptions) -func NewMCPServer(cfg *MCPServerConfig, deps ToolDependencies, inventory *inventory.Inventory) (*mcp.Server, error) { +func NewMCPServer(ctx context.Context, cfg *MCPServerConfig, deps ToolDependencies, inventory *inventory.Inventory) (*mcp.Server, error) { // Create the MCP server serverOpts := &mcp.ServerOptions{ Instructions: inventory.Instructions(), @@ -110,7 +110,7 @@ func NewMCPServer(cfg *MCPServerConfig, deps ToolDependencies, inventory *invent // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets // is empty - users enable toolsets at runtime via the dynamic tools below (but can // enable toolsets or tools explicitly that do need registration). - inventory.RegisterAll(context.Background(), ghServer, deps) + inventory.RegisterAll(ctx, ghServer, deps) // Register dynamic toolset management tools (enable/disable) - these are separate // meta-tools that control the inventory, not part of the inventory itself diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index d8ce27ed9..614738429 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -134,7 +134,7 @@ func TestNewMCPServer_CreatesSuccessfully(t *testing.T) { require.NoError(t, err, "expected inventory build to succeed") // Create the server - server, err := NewMCPServer(&cfg, deps, inv) + server, err := NewMCPServer(t.Context(), &cfg, deps, inv) require.NoError(t, err, "expected server creation to succeed") require.NotNil(t, server, "expected server to be non-nil") diff --git a/pkg/http/handler.go b/pkg/http/handler.go index cfce791b2..bee065196 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -4,8 +4,8 @@ import ( "context" "log/slog" "net/http" - "strings" + ghcontext "github.com/github/github-mcp-server/pkg/context" "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/headers" "github.com/github/github-mcp-server/pkg/http/middleware" @@ -16,9 +16,10 @@ import ( ) type InventoryFactoryFunc func(r *http.Request) (*inventory.Inventory, error) -type GitHubMCPServerFactoryFunc func(ctx context.Context, r *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) +type GitHubMCPServerFactoryFunc func(r *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) type HTTPMcpHandler struct { + ctx context.Context config *HTTPServerConfig deps github.ToolDependencies logger *slog.Logger @@ -46,7 +47,9 @@ func WithInventoryFactory(f InventoryFactoryFunc) HTTPMcpHandlerOption { } } -func NewHTTPMcpHandler(cfg *HTTPServerConfig, +func NewHTTPMcpHandler( + ctx context.Context, + cfg *HTTPServerConfig, deps github.ToolDependencies, t translations.TranslationHelperFunc, logger *slog.Logger, @@ -67,6 +70,7 @@ func NewHTTPMcpHandler(cfg *HTTPServerConfig, } return &HTTPMcpHandler{ + ctx: ctx, config: cfg, deps: deps, logger: logger, @@ -76,8 +80,33 @@ func NewHTTPMcpHandler(cfg *HTTPServerConfig, } } +// RegisterRoutes registers the routes for the MCP server +// URL-based values take precedence over header-based values func (h *HTTPMcpHandler) RegisterRoutes(r chi.Router) { + r.Use(middleware.WithRequestConfig) + r.Mount("/", h) + // Mount readonly and toolset routes + r.With(withToolset).Mount("/x/{toolset}", h) + r.With(withReadonly, withToolset).Mount("/x/{toolset}/readonly", h) + r.With(withReadonly).Mount("/readonly", h) +} + +// withReadonly is middleware that sets readonly mode in the request context +func withReadonly(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := ghcontext.WithReadonly(r.Context(), true) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// withToolset is middleware that extracts the toolset from the URL and sets it in the request context +func withToolset(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + toolset := chi.URLParam(r, "toolset") + ctx := ghcontext.WithToolsets(r.Context(), []string{toolset}) + next.ServeHTTP(w, r.WithContext(ctx)) + }) } func (h *HTTPMcpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -87,7 +116,7 @@ func (h *HTTPMcpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - ghServer, err := h.githubMcpServerFactory(r.Context(), r, h.deps, inventory, &github.MCPServerConfig{ + ghServer, err := h.githubMcpServerFactory(r, h.deps, inventory, &github.MCPServerConfig{ Version: h.config.Version, Translator: h.t, ContentWindowSize: h.config.ContentWindowSize, @@ -108,8 +137,8 @@ func (h *HTTPMcpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { middleware.ExtractUserToken()(mcpHandler).ServeHTTP(w, r) } -func DefaultGitHubMCPServerFactory(ctx context.Context, _ *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) { - return github.NewMCPServer(&github.MCPServerConfig{ +func DefaultGitHubMCPServerFactory(r *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) { + return github.NewMCPServer(r.Context(), &github.MCPServerConfig{ Version: cfg.Version, Translator: cfg.Translator, ContentWindowSize: cfg.ContentWindowSize, @@ -123,52 +152,37 @@ func DefaultInventoryFactory(cfg *HTTPServerConfig, t translations.TranslationHe b := github.NewInventory(t).WithDeprecatedAliases(github.DeprecatedToolAliases) // Feature checker composition - headerFeatures := parseCommaSeparatedHeader(r.Header.Get(headers.MCPFeaturesHeader)) + headerFeatures := headers.ParseCommaSeparated(r.Header.Get(headers.MCPFeaturesHeader)) if checker := ComposeFeatureChecker(headerFeatures, staticChecker); checker != nil { b = b.WithFeatureChecker(checker) } - b = InventoryFiltersForRequestHeaders(r, b) + b = InventoryFiltersForRequest(r, b) b.WithServerInstructions() return b.Build() } } -// InventoryFiltersForRequestHeaders applies inventory filters based on HTTP request headers. -// Whitespace is trimmed from comma-separated values; empty values are ignored. -func InventoryFiltersForRequestHeaders(r *http.Request, builder *inventory.Builder) *inventory.Builder { - if r.Header.Get(headers.MCPReadOnlyHeader) != "" { +// InventoryFiltersForRequest applies filters to the inventory builder +// based on the request context and headers +func InventoryFiltersForRequest(r *http.Request, builder *inventory.Builder) *inventory.Builder { + ctx := r.Context() + + if ghcontext.IsReadonly(ctx) { builder = builder.WithReadOnly(true) } - if toolsetsStr := r.Header.Get(headers.MCPToolsetsHeader); toolsetsStr != "" { - toolsets := parseCommaSeparatedHeader(toolsetsStr) + if toolsets := ghcontext.GetToolsets(ctx); len(toolsets) > 0 { builder = builder.WithToolsets(toolsets) } - if toolsStr := r.Header.Get(headers.MCPToolsHeader); toolsStr != "" { - tools := parseCommaSeparatedHeader(toolsStr) + if tools := ghcontext.GetTools(ctx); len(tools) > 0 { + if len(ghcontext.GetToolsets(ctx)) == 0 { + builder = builder.WithToolsets([]string{}) + } builder = builder.WithTools(github.CleanTools(tools)) } return builder } - -// parseCommaSeparatedHeader splits a header value by comma, trims whitespace, -// and filters out empty values. -func parseCommaSeparatedHeader(value string) []string { - if value == "" { - return []string{} - } - - parts := strings.Split(value, ",") - result := make([]string, 0, len(parts)) - for _, p := range parts { - trimmed := strings.TrimSpace(p) - if trimmed != "" { - result = append(result, trimmed) - } - } - return result -} diff --git a/pkg/http/handler_test.go b/pkg/http/handler_test.go new file mode 100644 index 000000000..83a2438d7 --- /dev/null +++ b/pkg/http/handler_test.go @@ -0,0 +1,101 @@ +package http + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mockTool(name, toolsetID string, readOnly bool) inventory.ServerTool { + return inventory.ServerTool{ + Tool: mcp.Tool{ + Name: name, + Annotations: &mcp.ToolAnnotations{ReadOnlyHint: readOnly}, + }, + Toolset: inventory.ToolsetMetadata{ + ID: inventory.ToolsetID(toolsetID), + Description: "Test: " + toolsetID, + }, + } +} + +func TestInventoryFiltersForRequest(t *testing.T) { + tools := []inventory.ServerTool{ + mockTool("get_file_contents", "repos", true), + mockTool("create_repository", "repos", false), + mockTool("list_issues", "issues", true), + mockTool("issue_write", "issues", false), + } + + tests := []struct { + name string + contextSetup func(context.Context) context.Context + expectedTools []string + }{ + { + name: "no filters applies defaults", + contextSetup: func(ctx context.Context) context.Context { return ctx }, + expectedTools: []string{"get_file_contents", "create_repository", "list_issues", "issue_write"}, + }, + { + name: "readonly from context filters write tools", + contextSetup: func(ctx context.Context) context.Context { + return ghcontext.WithReadonly(ctx, true) + }, + expectedTools: []string{"get_file_contents", "list_issues"}, + }, + { + name: "toolset from context filters to toolset", + contextSetup: func(ctx context.Context) context.Context { + return ghcontext.WithToolsets(ctx, []string{"repos"}) + }, + expectedTools: []string{"get_file_contents", "create_repository"}, + }, + { + name: "tools alone clears default toolsets", + contextSetup: func(ctx context.Context) context.Context { + return ghcontext.WithTools(ctx, []string{"list_issues"}) + }, + expectedTools: []string{"list_issues"}, + }, + { + name: "tools are additive with toolsets", + contextSetup: func(ctx context.Context) context.Context { + ctx = ghcontext.WithToolsets(ctx, []string{"repos"}) + ctx = ghcontext.WithTools(ctx, []string{"list_issues"}) + return ctx + }, + expectedTools: []string{"get_file_contents", "create_repository", "list_issues"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req = req.WithContext(tt.contextSetup(req.Context())) + + builder := inventory.NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}) + + builder = InventoryFiltersForRequest(req, builder) + inv, err := builder.Build() + require.NoError(t, err) + + available := inv.AvailableTools(context.Background()) + toolNames := make([]string, len(available)) + for i, tool := range available { + toolNames[i] = tool.Tool.Name + } + + assert.ElementsMatch(t, tt.expectedTools, toolNames) + }) + } +} diff --git a/pkg/http/headers/parse.go b/pkg/http/headers/parse.go new file mode 100644 index 000000000..2b5eddacd --- /dev/null +++ b/pkg/http/headers/parse.go @@ -0,0 +1,21 @@ +package headers + +import "strings" + +// ParseCommaSeparated splits a header value by comma, trims whitespace, +// and filters out empty values +func ParseCommaSeparated(value string) []string { + if value == "" { + return []string{} + } + + parts := strings.Split(value, ",") + result := make([]string, 0, len(parts)) + for _, p := range parts { + trimmed := strings.TrimSpace(p) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} diff --git a/pkg/http/headers/parse_test.go b/pkg/http/headers/parse_test.go new file mode 100644 index 000000000..d8b55a696 --- /dev/null +++ b/pkg/http/headers/parse_test.go @@ -0,0 +1,58 @@ +package headers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseCommaSeparated(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "empty string", + input: "", + expected: []string{}, + }, + { + name: "single value", + input: "foo", + expected: []string{"foo"}, + }, + { + name: "multiple values", + input: "foo,bar,baz", + expected: []string{"foo", "bar", "baz"}, + }, + { + name: "whitespace trimmed", + input: " foo , bar , baz ", + expected: []string{"foo", "bar", "baz"}, + }, + { + name: "empty values filtered", + input: "foo,,bar,", + expected: []string{"foo", "bar"}, + }, + { + name: "only commas", + input: ",,,", + expected: []string{}, + }, + { + name: "whitespace only values filtered", + input: "foo, ,bar", + expected: []string{"foo", "bar"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParseCommaSeparated(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/http/middleware/request_config.go b/pkg/http/middleware/request_config.go new file mode 100644 index 000000000..0f7d3b2c7 --- /dev/null +++ b/pkg/http/middleware/request_config.go @@ -0,0 +1,40 @@ +package middleware + +import ( + "net/http" + "slices" + "strings" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/headers" +) + +// WithRequestConfig is a middleware that extracts MCP-related headers and sets them in the request context +func WithRequestConfig(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + if relaxedParseBool(r.Header.Get(headers.MCPReadOnlyHeader)) { + ctx = ghcontext.WithReadonly(ctx, true) + } + + if toolsets := headers.ParseCommaSeparated(r.Header.Get(headers.MCPToolsetsHeader)); len(toolsets) > 0 { + ctx = ghcontext.WithToolsets(ctx, toolsets) + } + + if tools := headers.ParseCommaSeparated(r.Header.Get(headers.MCPToolsHeader)); len(tools) > 0 { + ctx = ghcontext.WithTools(ctx, tools) + } + + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// relaxedParseBool parses a string into a boolean value, treating various +// common false values or empty strings as false, and everything else as true. +// It is case-insensitive and trims whitespace. +func relaxedParseBool(s string) bool { + s = strings.TrimSpace(strings.ToLower(s)) + falseValues := []string{"", "false", "0", "no", "off", "n", "f"} + return !slices.Contains(falseValues, s) +} diff --git a/pkg/http/server.go b/pkg/http/server.go index c6054f727..c14ae9eee 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -98,7 +98,7 @@ func RunHTTPServer(cfg HTTPServerConfig) error { r := chi.NewRouter() - handler := NewHTTPMcpHandler(&cfg, deps, t, logger) + handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger) handler.RegisterRoutes(r) addr := fmt.Sprintf(":%d", cfg.Port)