Files
LocalAGI/core/agent/mcp.go
2026-01-31 09:19:26 +01:00

223 lines
6.4 KiB
Go

package agent
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"time"
"net/http"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/mudler/LocalAGI/core/types"
"github.com/mudler/xlog"
"github.com/sashabaranov/go-openai/jsonschema"
)
// Compile-time assertion that *mcpWrapperAction implements types.Action.
var _ types.Action = (*mcpWrapperAction)(nil)
// MCPServer holds the connection settings for an MCP server reached over HTTP.
// NOTE(review): initMCPActions appears to consume values of this type from the
// agent options (it reads .URL and .Token) — confirm against the options definition.
type MCPServer struct {
// URL is serialized as "url"; initMCPActions uses it as the transport endpoint.
URL string `json:"url"`
// Token is serialized as "token"; when non-empty it is sent as a bearer token.
Token string `json:"token"`
}
// MCPSTDIOServer holds the settings for an MCP server that is started as a
// local command.
// NOTE(review): initMCPActions appears to consume values of this type from the
// agent options (it reads .Cmd, .Args and .Env) — confirm against the options definition.
type MCPSTDIOServer struct {
// Args are the arguments passed to Cmd.
Args []string `json:"args"`
// Env entries are appended to the current process environment for the command.
Env []string `json:"env"`
// Cmd is the executable to run.
Cmd string `json:"cmd"`
}
// mcpWrapperAction exposes a single tool advertised by an MCP server as a
// types.Action. Instances are created by addTools.
type mcpWrapperAction struct {
// mcpClient is the session the tool was listed from.
mcpClient *mcp.ClientSession
// inputSchema is the tool's input schema; Definition converts it to action properties.
inputSchema ToolInputSchema
// toolName is reported as the action definition name.
toolName string
// toolDescription is reported as the action definition description.
toolDescription string
}
// Run exists only so that mcpWrapperAction satisfies types.Action. It never
// invokes the MCP tool: it always returns a placeholder result together with a
// "not implemented" error.
func (m *mcpWrapperAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
	// The tool call itself is performed by cogito, not here. These wrappers
	// only provide the list of actions an MCP server offers, which is used
	// for resolving internal states.
	placeholder := types.ActionResult{Result: "MCP action called"}
	notImplemented := fmt.Errorf("not implemented")
	return placeholder, notImplemented
}
// Definition builds the types.ActionDefinition for the wrapped MCP tool.
//
// The tool's schema properties are converted to jsonschema.Definition values
// by a JSON round-trip. Conversion is best-effort: a marshal or unmarshal
// failure is logged and the definition is still returned, with whatever
// properties could be decoded (possibly none).
func (m *mcpWrapperAction) Definition() types.ActionDefinition {
	props := map[string]jsonschema.Definition{}

	raw, err := json.Marshal(m.inputSchema.Properties)
	if err != nil {
		xlog.Error("Failed to marshal input schema", "tool", m.toolName, "error", err.Error())
	} else if err := json.Unmarshal(raw, &props); err != nil {
		// Entries decoded before the failure are kept in props.
		xlog.Error("Failed to unmarshal input schema", "tool", m.toolName, "error", err.Error())
	}

	return types.ActionDefinition{
		Name:        types.ActionDefinitionName(m.toolName),
		Description: m.toolDescription,
		Required:    m.inputSchema.Required,
		Properties:  props,
	}
}
// ToolInputSchema is the subset of a tool's input schema that this package
// keeps: the schema type, its properties and the list of required property
// names. addTools fills it by re-decoding the JSON form of the tool's schema.
type ToolInputSchema struct {
// Type is the value of the schema's "type" key.
Type string `json:"type"`
// Properties maps property names to their still-untyped schema.
Properties map[string]interface{} `json:"properties,omitempty"`
// Required lists the names of the mandatory properties.
Required []string `json:"required,omitempty"`
}
// addTools lists the tools offered by the given MCP session and wraps each
// one in an mcpWrapperAction.
//
// It returns an error only when the tool listing itself fails. A tool whose
// input schema cannot be converted is still registered, with an empty schema,
// and the conversion failure is logged.
func (a *Agent) addTools(client *mcp.ClientSession) (types.Actions, error) {
	// NOTE(review): ListTools is called once with nil params; if the server
	// paginates its tool list, later pages are presumably not fetched — confirm
	// against the MCP SDK documentation.
	tools, err := client.ListTools(a.context, nil)
	if err != nil {
		// Not logged here: the caller logs the returned error with server details.
		return nil, fmt.Errorf("listing MCP tools: %w", err)
	}

	generatedActions := make(types.Actions, 0, len(tools.Tools))
	for _, t := range tools.Tools {
		xlog.Debug("Tool", "name", t.Name, "description", t.Description)

		// XXX: the conversion below is a guess that the SDK schema serializes
		// to a shape compatible with ToolInputSchema — to verify.
		var inputSchema ToolInputSchema
		dat, err := json.Marshal(t.InputSchema)
		if err != nil {
			xlog.Error("Failed to marshal input schema", "tool", t.Name, "error", err.Error())
		} else {
			xlog.Debug("Input schema", "tool", t.Name, "schema", string(dat))
			if err := json.Unmarshal(dat, &inputSchema); err != nil {
				xlog.Error("Failed to unmarshal input schema", "tool", t.Name, "error", err.Error())
			}
		}

		// Bind the tool to the session it was listed from.
		generatedActions = append(generatedActions, &mcpWrapperAction{
			mcpClient:       client,
			toolName:        t.Name,
			inputSchema:     inputSchema,
			toolDescription: t.Description,
		})
	}

	return generatedActions, nil
}
// bearerTokenRoundTripper is an http.RoundTripper that adds an
// "Authorization: Bearer <token>" header to outgoing requests before handing
// them to a base transport.
type bearerTokenRoundTripper struct {
	token string            // bearer token; when empty, requests are forwarded untouched
	base  http.RoundTripper // transport that performs the request; nil means http.DefaultTransport
}

// RoundTrip implements the http.RoundTripper interface.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the header is set on a clone and req itself is left unchanged. Any
// Authorization header already present is replaced on the clone.
func (rt *bearerTokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	base := rt.base
	if base == nil {
		// Guards values built without newBearerTokenRoundTripper.
		base = http.DefaultTransport
	}
	if rt.token == "" {
		return base.RoundTrip(req)
	}

	authed := req.Clone(req.Context())
	if authed.Header == nil {
		// Clone keeps a nil header nil; Set on a nil map would panic.
		authed.Header = make(http.Header)
	}
	authed.Header.Set("Authorization", "Bearer "+rt.token)
	return base.RoundTrip(authed)
}

// newBearerTokenRoundTripper returns a round tripper that injects token as a
// bearer token and delegates to base. A nil base selects http.DefaultTransport.
func newBearerTokenRoundTripper(token string, base http.RoundTripper) http.RoundTripper {
	if base == nil {
		base = http.DefaultTransport
	}
	return &bearerTokenRoundTripper{
		token: token,
		base:  base,
	}
}
// initMCPActions (re)connects to every configured MCP server and stores the
// tools they expose in a.mcpActionDefinitions.
//
// Steps, in order:
//  1. close sessions left over from a previous call;
//  2. connect to each HTTP server, trying the streamable transport first and
//     falling back to SSE;
//  3. run the optional prepare script through /bin/bash;
//  4. start and connect to each stdio server.
//
// Initialisation is best-effort: a server that cannot be reached, or whose
// tools cannot be listed, is logged and skipped. The returned error is
// therefore always nil; it is kept for signature compatibility.
func (a *Agent) initMCPActions() error {
	a.closeMCPServers() // Make sure we stop all previous servers if any is active
	a.mcpActionDefinitions = nil

	generatedActions := types.Actions{}
	client := mcp.NewClient(&mcp.Implementation{Name: "LocalAI", Version: "v1.0.0"}, nil)

	// MCP HTTP Servers
	for _, mcpServer := range a.options.mcpServers {
		// HTTP client with a custom round tripper for bearer token injection.
		// NOTE(review): http.Client.Timeout also bounds reading the response
		// body, so long-lived streams are presumably cut after 360s — confirm
		// this is intended.
		httpclient := &http.Client{
			Timeout:   360 * time.Second,
			Transport: newBearerTokenRoundTripper(mcpServer.Token, http.DefaultTransport),
		}

		// Only the URL is logged: the server struct also carries the token.
		streamableTransport := &mcp.StreamableClientTransport{HTTPClient: httpclient, Endpoint: mcpServer.URL}
		session, err := client.Connect(a.context, streamableTransport, nil)
		if err != nil {
			xlog.Error("Failed to connect to MCP server via StreamableClientTransport", "server", mcpServer.URL, "error", err.Error())

			sseTransport := &mcp.SSEClientTransport{HTTPClient: httpclient, Endpoint: mcpServer.URL}
			session, err = client.Connect(a.context, sseTransport, nil)
			if err != nil {
				xlog.Error("Failed to connect to MCP server via SSEClientTransport", "server", mcpServer.URL, "error", err.Error())
				continue
			}
		}

		// Track the session even if listing tools fails, so it gets closed later.
		a.mcpSessions = append(a.mcpSessions, session)

		xlog.Debug("Adding tools for MCP server", "server", mcpServer.URL)
		actions, err := a.addTools(session)
		if err != nil {
			xlog.Error("Failed to add tools for MCP server", "server", mcpServer.URL, "error", err.Error())
			continue
		}
		generatedActions = append(generatedActions, actions...)
	}

	// MCP STDIO Servers
	if a.options.mcpPrepareScript != "" {
		xlog.Debug("Preparing MCP", "script", a.options.mcpPrepareScript)
		prepareCmd := exec.Command("/bin/bash", "-c", a.options.mcpPrepareScript)
		output, err := prepareCmd.CombinedOutput()
		if err != nil {
			// A failing script does not stop the stdio servers from being started.
			xlog.Error("Failed to run MCP prepare script", "error", err.Error(), "output", string(output))
		} else {
			xlog.Debug("Prepared MCP", "output", string(output))
		}
	}

	for _, mcpStdioServer := range a.options.mcpStdioServers {
		command := exec.Command(mcpStdioServer.Cmd, mcpStdioServer.Args...)
		// The child gets the parent environment plus the configured entries.
		command.Env = append(os.Environ(), mcpStdioServer.Env...)

		// Env is not logged: it may hold credentials.
		session, err := client.Connect(a.context, &mcp.CommandTransport{Command: command}, nil)
		if err != nil {
			xlog.Error("Failed to connect to MCP server", "server", mcpStdioServer.Cmd, "args", mcpStdioServer.Args, "error", err.Error())
			continue
		}

		a.mcpSessions = append(a.mcpSessions, session)

		xlog.Debug("Adding tools for MCP server (stdio)", "server", mcpStdioServer.Cmd, "args", mcpStdioServer.Args)
		actions, err := a.addTools(session)
		if err != nil {
			xlog.Error("Failed to add tools for MCP server", "server", mcpStdioServer.Cmd, "args", mcpStdioServer.Args, "error", err.Error())
			continue
		}
		generatedActions = append(generatedActions, actions...)
	}

	a.mcpActionDefinitions = generatedActions

	return nil
}
// closeMCPServers closes every tracked MCP session and then forgets them, so
// a later call (or a re-initialisation) does not close the same session twice
// and the slice does not grow across re-initialisations.
// Close failures are only logged: shutdown is best-effort.
func (a *Agent) closeMCPServers() {
	for _, s := range a.mcpSessions {
		if err := s.Close(); err != nil {
			xlog.Debug("Failed to close MCP session", "error", err.Error())
		}
	}
	a.mcpSessions = nil
}