第17章 自定义组件开发
AgentScope-Java采用高度模块化的架构设计,所有核心组件都基于接口定义,开发者可以通过实现这些接口来扩展框架功能。本章将详细介绍如何开发自定义的Model、Memory、Tool、Formatter和Hook组件。
17.1 组件扩展架构概览
17.1.1 可扩展组件体系
AgentScope-Java的核心组件都遵循接口-实现分离原则:
┌─────────────────────────────────────────────────────────────────┐
│ AgentScope 组件体系 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Model │ │ Memory │ │ Formatter │ │
│ │ (模型接口) │ │ (记忆接口) │ │ (格式化器) │ │
│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │ │
│ ┌──────┴──────┐ ┌──────┴──────┐ ┌──────┴──────┐ │
│ │ChatModelBase│ │InMemoryMemory│ │AbstractBase │ │
│ │ (基类) │ │ (内存) │ │ Formatter │ │
│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │ │
│ ┌──────┴──────┐ ┌──────┴──────┐ ┌──────┴──────┐ │
│ │OpenAIChatM..│ │LongTermMemory│ │OpenAIChat.. │ │
│ │DashScopeCh..│ │DatabaseMemory│ │DashScopeCh..│ │
│ │AnthropicCh..│ │RedisMemory │ │GeminiChatF..│ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ Tool │ │ Hook │ │
│ │ (工具注解) │ │ (钩子接口) │ │
│ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │
│ ┌──────┴──────┐ ┌──────┴──────┐ │
│ │@Tool方法 │ │LoggingHook │ │
│ │Toolkit容器 │ │RAGHook │ │
│ │ToolResult.. │ │CustomHook │ │
│ └─────────────┘ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
17.1.2 扩展原则
开发自定义组件时应遵循以下原则:
| 原则 | 说明 | 示例 |
|---|---|---|
| 接口优先 | 实现标准接口而非继承具体类 | 实现Memory接口 |
| 单一职责 | 每个组件专注于单一功能 | Memory只负责消息存储 |
| 无状态设计 | 组件实例应尽量避免可变状态 | Tool方法只依赖不可变的注入依赖(final字段),不在成员变量中保存调用间状态 |
| 线程安全 | 组件需支持并发访问 | 使用CopyOnWriteArrayList;"先检查后修改"等复合操作仍需加锁 |
| 响应式优先 | 返回Mono/Flux支持异步 | Hook返回Mono |
17.2 自定义ChatModel
17.2.1 Model接口体系
Model是AgentScope-Java中最核心的组件,负责与LLM进行通信:
// Model interface definition - the top-level interface for all models
public interface Model {
/**
* Streams a chat response.
*
* @param messages list of AgentScope messages
* @param tools schema list of the available tools (optional)
* @param options generation options (optional)
* @return the response stream
*/
Flux<ChatResponse> stream(
List<Msg> messages,
List<ToolSchema> tools,
GenerateOptions options
);
/**
* Returns the model name.
*/
String getModelName();
}
// ChatModelBase abstract base class - provides shared functionality
public abstract class ChatModelBase implements Model {
/**
* Final entry point: hands the call to the tracer obtained from
* TracerRegistry, passing doStream(...) as the deferred invocation.
* Subclasses customize behavior through doStream only.
*/
@Override
public final Flux<ChatResponse> stream(
List<Msg> messages,
List<ToolSchema> tools,
GenerateOptions options) {
// Adds tracing automatically
return TracerRegistry.get()
.callModel(this, messages, tools, options,
() -> doStream(messages, tools, options));
}
// Subclasses implement the actual streaming generation logic
protected abstract Flux<ChatResponse> doStream(
List<Msg> messages,
List<ToolSchema> tools,
GenerateOptions options);
}
17.2.2 实现自定义ChatModel
假设我们要接入一个自定义的LLM服务:
import io.agentscope.core.model.*;
import io.agentscope.core.message.*;
import io.agentscope.core.formatter.Formatter;
import reactor.core.publisher.Flux;
/**
 * Example of a custom ChatModel implementation that connects to an
 * in-house enterprise LLM service.
 *
 * <p>Instances are created through {@link Builder}; the constructor is private.
 */
public class EnterpriseChatModel extends ChatModelBase {
    // Collaborators
    private final EnterpriseClient client;           // HTTP client
    private final EnterpriseFormatter formatter;     // message formatter
    private final GenerateOptions configuredOptions; // options fixed at build time

    // Private constructor; use the Builder.
    private EnterpriseChatModel(
            EnterpriseClient client,
            EnterpriseFormatter formatter,
            GenerateOptions options) {
        this.client = client;
        this.formatter = formatter;
        this.configuredOptions = options;
    }

    /**
     * Formats the messages, builds a streaming request and parses every
     * streamed chunk into a {@code ChatResponse}.
     *
     * @param messages AgentScope messages to send
     * @param tools    tool schemas to expose; ignored when null or empty
     * @param options  per-call options, merged with the options configured at build time
     * @return the stream of parsed responses; chunks that parse to {@code null} are skipped
     */
    @Override
    protected Flux<ChatResponse> doStream(
            List<Msg> messages,
            List<ToolSchema> tools,
            GenerateOptions options) {
        // 1. Merge options (per-call options are intended to take precedence
        //    over the configured ones - confirm against GenerateOptions.mergeOptions).
        GenerateOptions effectiveOptions =
                GenerateOptions.mergeOptions(options, configuredOptions);

        // 2. Convert the messages to the provider format.
        List<EnterpriseMessage> formattedMessages = formatter.format(messages);

        // 3. Build the request.
        EnterpriseRequest request = EnterpriseRequest.builder()
                .model(effectiveOptions.getModelName())
                .messages(formattedMessages)
                .stream(true)
                .build();

        // 4. Apply the tool configuration.
        if (tools != null && !tools.isEmpty()) {
            formatter.applyTools(request, tools);
        }

        // 5. Apply the generation options (already merged, so no separate defaults).
        formatter.applyOptions(request, effectiveOptions, null);

        // 6. Issue the streaming request.
        Instant startTime = Instant.now();
        // mapNotNull instead of map + filter(nonNull): Flux.map fails with a
        // NullPointerException when the mapper returns null, so a later
        // null filter can never take effect.
        return client.streamCall(request)
                .mapNotNull(response -> formatter.parseResponse(response, startTime));
    }

    /**
     * Returns the model name from the configured options, or {@code null}
     * when no options were configured.
     */
    @Override
    public String getModelName() {
        return configuredOptions != null
                ? configuredOptions.getModelName()
                : null;
    }

    /** Creates a new builder. */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link EnterpriseChatModel}. */
    public static class Builder {
        private String apiKey;
        private String modelName;
        private String baseUrl;
        private Boolean stream;
        private EnterpriseClient client;
        private EnterpriseFormatter formatter;

        public Builder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        public Builder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public Builder baseUrl(String baseUrl) {
            this.baseUrl = baseUrl;
            return this;
        }

        public Builder stream(boolean stream) {
            this.stream = stream;
            return this;
        }

        public Builder client(EnterpriseClient client) {
            this.client = client;
            return this;
        }

        public Builder formatter(EnterpriseFormatter formatter) {
            this.formatter = formatter;
            return this;
        }

        /**
         * Builds the model.
         *
         * @throws NullPointerException if {@code modelName} or {@code apiKey} was not set
         */
        public EnterpriseChatModel build() {
            Objects.requireNonNull(modelName, "modelName must be set");
            Objects.requireNonNull(apiKey, "apiKey must be set");

            // Assemble the configured options; streaming defaults to true.
            GenerateOptions options = GenerateOptions.builder()
                    .apiKey(apiKey)
                    .modelName(modelName)
                    .baseUrl(baseUrl)
                    .stream(stream != null ? stream : true)
                    .build();

            // Fall back to default collaborators when none were supplied.
            EnterpriseClient effectiveClient =
                    client != null ? client : new EnterpriseClient();
            EnterpriseFormatter effectiveFormatter =
                    formatter != null ? formatter : new EnterpriseFormatter();

            return new EnterpriseChatModel(
                    effectiveClient,
                    effectiveFormatter,
                    options);
        }
    }
}
17.2.3 使用自定义ChatModel
// Create the custom model
Model enterpriseModel = EnterpriseChatModel.builder()
.apiKey("your-api-key")
.modelName("enterprise-llm-v2")
.baseUrl("https://internal-llm.company.com/api")
.stream(true)
.build();
// Create an Agent that uses the custom model
ReActAgent agent = ReActAgent.builder()
.name("EnterpriseAssistant")
.model(enterpriseModel)
.systemPrompt("你是企业智能助手")
.build();
// Use it as usual
Msg response = agent.run(Msg.user("分析上季度销售数据")).block();
17.3 自定义Memory
17.3.1 Memory接口定义
Memory组件负责存储和管理Agent的对话历史:
/**
* Memory interface - conversation memory management.
* Extends StateModule to support session persistence.
*/
public interface Memory extends StateModule {
/**
* Adds a message to the memory.
*/
void addMessage(Msg message);
/**
* Returns all stored messages.
*/
List<Msg> getMessages();
/**
* Deletes the message at the given index.
* An out-of-range index is silently ignored.
*/
void deleteMessage(int index);
/**
* Removes all messages.
*/
void clear();
}
/**
* StateModule interface - session state management.
*/
public interface StateModule {
/**
* Saves the state to the session.
*/
void saveTo(Session session, SessionKey sessionKey);
/**
* Loads the state from the session.
*/
void loadFrom(Session session, SessionKey sessionKey);
}
17.3.2 实现滑动窗口Memory
import io.agentscope.core.memory.Memory;
import io.agentscope.core.message.Msg;
import io.agentscope.core.message.MsgRole;
import io.agentscope.core.session.Session;
import io.agentscope.core.state.SessionKey;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
/**
 * Sliding-window {@link Memory} implementation.
 *
 * <p>Keeps only the most recent {@code windowSize} messages and drops older
 * ones automatically. When {@code preserveSystemMessages} is enabled, system
 * messages are never evicted, do not count towards the window, and are kept
 * at the front of the list.
 *
 * <p>Thread-safety: every method that touches the message list is
 * {@code synchronized} on this instance. A concurrent collection alone is not
 * enough here, because "add then trim", "check index then remove" and
 * "clear then refill" are compound operations that must be atomic for readers.
 */
public class SlidingWindowMemory implements Memory {
    // Prefix of the keys used in the session store
    private static final String KEY_PREFIX = "sliding_window_memory";

    // Message storage; guarded by "this"
    private final List<Msg> messages = new ArrayList<>();
    // Window size (number of messages to retain)
    private final int windowSize;
    // Whether system messages are kept outside of the window accounting
    private final boolean preserveSystemMessages;

    /**
     * Creates a sliding-window memory.
     *
     * @param windowSize             number of most recent messages to retain
     * @param preserveSystemMessages whether system messages are always retained
     * @throws IllegalArgumentException if {@code windowSize} is less than 1
     */
    public SlidingWindowMemory(int windowSize, boolean preserveSystemMessages) {
        if (windowSize < 1) {
            throw new IllegalArgumentException("windowSize must be >= 1");
        }
        this.windowSize = windowSize;
        this.preserveSystemMessages = preserveSystemMessages;
    }

    /**
     * Convenience constructor that preserves system messages.
     */
    public SlidingWindowMemory(int windowSize) {
        this(windowSize, true);
    }

    @Override
    public synchronized void addMessage(Msg message) {
        messages.add(message);
        // Slide the window
        trimToWindow();
    }

    /**
     * Trims the message list so that it fits the window.
     * Callers must hold the lock on this instance.
     */
    private void trimToWindow() {
        if (!preserveSystemMessages) {
            // Apply the window to all messages: drop the oldest surplus in one step.
            int excess = messages.size() - windowSize;
            if (excess > 0) {
                messages.subList(0, excess).clear();
            }
            return;
        }

        // Split into system and non-system messages.
        List<Msg> systemMessages = new ArrayList<>();
        List<Msg> nonSystemMessages = new ArrayList<>();
        for (Msg msg : messages) {
            if (msg.getRole() == MsgRole.SYSTEM) {
                systemMessages.add(msg);
            } else {
                nonSystemMessages.add(msg);
            }
        }

        // The window only applies to non-system messages.
        int excess = nonSystemMessages.size() - windowSize;
        if (excess > 0) {
            nonSystemMessages.subList(0, excess).clear();
        }

        // Rebuild the list with system messages first.
        messages.clear();
        messages.addAll(systemMessages);
        messages.addAll(nonSystemMessages);
    }

    /** Returns a snapshot copy of the stored messages. */
    @Override
    public synchronized List<Msg> getMessages() {
        return new ArrayList<>(messages);
    }

    /** Deletes the message at {@code index}; an out-of-range index is ignored. */
    @Override
    public synchronized void deleteMessage(int index) {
        if (index >= 0 && index < messages.size()) {
            messages.remove(index);
        }
    }

    @Override
    public synchronized void clear() {
        messages.clear();
    }

    // ==================== StateModule implementation ====================

    @Override
    public synchronized void saveTo(Session session, SessionKey sessionKey) {
        // Save a copy of the message list
        session.save(sessionKey, KEY_PREFIX + "_messages",
                new ArrayList<>(messages));
        // Save the window size configuration
        session.save(sessionKey, KEY_PREFIX + "_windowSize", windowSize);
    }

    @Override
    public synchronized void loadFrom(Session session, SessionKey sessionKey) {
        List<Msg> loaded = session.getList(
                sessionKey,
                KEY_PREFIX + "_messages",
                Msg.class
        );
        messages.clear();
        // Guard against a session that holds no entry for this memory.
        if (loaded != null) {
            messages.addAll(loaded);
        }
        // Re-apply the window after loading
        trimToWindow();
    }

    // ==================== Helpers ====================

    /**
     * Returns the configured window size.
     */
    public int getWindowSize() {
        return windowSize;
    }

    /**
     * Returns the current number of stored messages.
     */
    public synchronized int size() {
        return messages.size();
    }

    /**
     * Returns a copy of the most recent {@code n} messages.
     *
     * @param n number of messages requested; values below 0 are treated as 0
     *          and values above the current size return all messages
     */
    public synchronized List<Msg> getRecentMessages(int n) {
        int total = messages.size();
        int count = Math.max(0, Math.min(n, total));
        return new ArrayList<>(messages.subList(total - count, total));
    }
}
17.3.3 实现Token限制Memory
import io.agentscope.core.memory.Memory;
import io.agentscope.core.message.*;
import io.agentscope.core.session.Session;
import io.agentscope.core.state.SessionKey;
import java.util.*;
/**
 * Token-limited {@link Memory} implementation.
 *
 * <p>Bounds the memory by an estimated token count so that the history does
 * not exceed the model's context window. When a new message would exceed the
 * budget, the oldest non-system messages are evicted first; system messages
 * are never evicted.
 *
 * <p>Thread-safety: all methods that touch the message list or the running
 * token count are {@code synchronized} on this instance, so the two always
 * change together.
 */
public class TokenLimitedMemory implements Memory {
    // Message storage; guarded by "this"
    private final List<Msg> messages = new ArrayList<>();
    private final int maxTokens;
    private final TokenCounter tokenCounter;
    // Sum of the token counts of all stored messages; guarded by "this"
    private int currentTokenCount = 0;

    /**
     * Token counter abstraction.
     */
    @FunctionalInterface
    public interface TokenCounter {
        int countTokens(Msg message);
    }

    /**
     * Simple token estimator based on the number of characters.
     */
    public static class SimpleTokenCounter implements TokenCounter {
        private final double charsPerToken;

        /**
         * @param charsPerToken average number of characters per token
         * @throws IllegalArgumentException if {@code charsPerToken} is not a positive number
         */
        public SimpleTokenCounter(double charsPerToken) {
            // Also rejects NaN, which would otherwise produce a zero estimate.
            if (!(charsPerToken > 0)) {
                throw new IllegalArgumentException("charsPerToken must be > 0");
            }
            this.charsPerToken = charsPerToken;
        }

        /** Estimates tokens from the total length of the message's text blocks. */
        @Override
        public int countTokens(Msg message) {
            // Sum the lengths directly instead of concatenating every text block.
            int chars = message.getContent().stream()
                    .filter(b -> b instanceof TextBlock)
                    .mapToInt(b -> {
                        String text = ((TextBlock) b).getText();
                        return text == null ? 0 : text.length();
                    })
                    .sum();
            return (int) Math.ceil(chars / charsPerToken);
        }
    }

    /**
     * @param maxTokens    token budget of this memory
     * @param tokenCounter strategy used to count the tokens of a message
     * @throws IllegalArgumentException if {@code maxTokens} is less than 1
     * @throws NullPointerException     if {@code tokenCounter} is null
     */
    public TokenLimitedMemory(int maxTokens, TokenCounter tokenCounter) {
        if (maxTokens < 1) {
            throw new IllegalArgumentException("maxTokens must be >= 1");
        }
        this.maxTokens = maxTokens;
        this.tokenCounter = Objects.requireNonNull(tokenCounter, "tokenCounter");
    }

    /**
     * Uses the default token counter (assumes 4 characters per token).
     */
    public TokenLimitedMemory(int maxTokens) {
        this(maxTokens, new SimpleTokenCounter(4.0));
    }

    /**
     * Adds a message, evicting the oldest non-system messages until it fits.
     * The new message is always stored, even when it alone exceeds the budget.
     */
    @Override
    public synchronized void addMessage(Msg message) {
        int newTokens = tokenCounter.countTokens(message);
        // Evict old messages until there is enough room
        while (currentTokenCount + newTokens > maxTokens && !messages.isEmpty()) {
            // System messages are skipped
            int removeIndex = findFirstNonSystemMessage();
            if (removeIndex < 0) {
                break;
            }
            Msg removed = messages.remove(removeIndex);
            currentTokenCount -= tokenCounter.countTokens(removed);
        }
        // Store the new message
        messages.add(message);
        currentTokenCount += newTokens;
    }

    // Returns the index of the oldest non-system message, or -1 if there is none.
    // Callers must hold the lock on this instance.
    private int findFirstNonSystemMessage() {
        for (int i = 0; i < messages.size(); i++) {
            if (messages.get(i).getRole() != MsgRole.SYSTEM) {
                return i;
            }
        }
        return -1;
    }

    /** Returns a snapshot copy of the stored messages. */
    @Override
    public synchronized List<Msg> getMessages() {
        return new ArrayList<>(messages);
    }

    /** Deletes the message at {@code index}; an out-of-range index is ignored. */
    @Override
    public synchronized void deleteMessage(int index) {
        if (index >= 0 && index < messages.size()) {
            Msg removed = messages.remove(index);
            currentTokenCount -= tokenCounter.countTokens(removed);
        }
    }

    @Override
    public synchronized void clear() {
        messages.clear();
        currentTokenCount = 0;
    }

    @Override
    public synchronized void saveTo(Session session, SessionKey sessionKey) {
        session.save(sessionKey, "token_memory_messages",
                new ArrayList<>(messages));
    }

    /**
     * Replaces the current content with the messages stored in the session.
     * Messages are re-added one by one so the token budget is enforced on the
     * loaded history as well.
     */
    @Override
    public synchronized void loadFrom(Session session, SessionKey sessionKey) {
        List<Msg> loaded = session.getList(
                sessionKey, "token_memory_messages", Msg.class);
        messages.clear();
        currentTokenCount = 0;
        // Guard against a session that holds no entry for this memory.
        if (loaded == null) {
            return;
        }
        for (Msg msg : loaded) {
            addMessage(msg);
        }
    }

    /**
     * Returns the current token usage.
     */
    public synchronized int getCurrentTokenCount() {
        return currentTokenCount;
    }

    /**
     * Returns the remaining token capacity, never less than 0.
     * Usage can exceed the budget when system messages or a single oversized
     * message cannot be evicted.
     */
    public synchronized int getRemainingTokens() {
        return Math.max(0, maxTokens - currentTokenCount);
    }
}
17.3.4 使用自定义Memory
// Use the sliding-window Memory
Memory slidingMemory = new SlidingWindowMemory(20, true);
// Use the token-limited Memory
Memory tokenMemory = new TokenLimitedMemory(4000);
// Specify the Memory when creating the Agent
ReActAgent agent = ReActAgent.builder()
.name("MemoryAgent")
.model(model)
.memory(slidingMemory) // or tokenMemory
.systemPrompt("你是一个智能助手")
.build();
17.4 自定义Tool
17.4.1 Tool注解系统
AgentScope-Java使用注解来定义工具:
/**
* {@code @Tool} annotation - marks a method as a tool that an Agent can invoke.
*/
@Target({ElementType.METHOD, ElementType.ANNOTATION_TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Tool {
// Tool name (defaults to the method name)
String name() default "";
// Tool description (tells the LLM when to use this tool)
String description() default "";
// Result converter (custom formatting of the result)
Class<? extends ToolResultConverter> converter()
default DefaultToolResultConverter.class;
}
/**
* {@code @ToolParam} annotation - describes a tool parameter.
*/
@Target({ElementType.PARAMETER, ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface ToolParam {
// Parameter name (required, because Java does not retain parameter names by default)
String name();
// Whether the parameter is required
boolean required() default true;
// Parameter description
String description() default "";
}
17.4.2 创建基础工具
import io.agentscope.core.tool.*;
/**
 * Data analysis tool set.
 *
 * <p>Both tools report problems as {@code "Error: ..."} strings instead of
 * throwing, so the calling model receives a readable result.
 */
public class DataAnalysisTools {
    // Row limit applied when the caller passes none (or a non-positive one)
    private static final int DEFAULT_ROW_LIMIT = 100;

    // Matches the forbidden statements as whole words, regardless of case and
    // of the whitespace that follows (space, tab, newline, end of input).
    private static final java.util.regex.Pattern FORBIDDEN_KEYWORDS =
            java.util.regex.Pattern.compile(
                    "\\b(DROP|DELETE|TRUNCATE)\\b",
                    java.util.regex.Pattern.CASE_INSENSITIVE);

    private final DataService dataService;

    public DataAnalysisTools(DataService dataService) {
        this.dataService = dataService;
    }

    /**
     * Runs a SQL query against the database.
     *
     * @param sql   query text; must not be blank
     * @param limit maximum number of rows; null or non-positive values fall
     *              back to the default of 100
     * @return the formatted result, or an error description
     */
    @Tool(
        name = "query_database",
        description = "Execute SQL query on the database. Use this when you need to retrieve or analyze data from the database."
    )
    public String queryDatabase(
        @ToolParam(name = "sql", description = "SQL query to execute")
        String sql,
        @ToolParam(name = "limit", description = "Maximum number of rows", required = false)
        Integer limit
    ) {
        // Parameter validation
        if (sql == null || sql.trim().isEmpty()) {
            return "Error: SQL query cannot be empty";
        }
        // Safety check. NOTE: a keyword filter is only a first line of defence;
        // the database account used by DataService should itself be read-only.
        if (containsDangerousKeywords(sql)) {
            return "Error: Query contains forbidden keywords";
        }
        // Run the query
        try {
            int effectiveLimit = limit != null && limit > 0 ? limit : DEFAULT_ROW_LIMIT;
            QueryResult result = dataService.executeQuery(sql, effectiveLimit);
            return formatQueryResult(result);
        } catch (Exception e) {
            // Tool boundary: surface the failure to the model as text.
            return "Error executing query: " + e.getMessage();
        }
    }

    /**
     * Generates a statistical report.
     *
     * @param metrics   comma-separated metric names; blanks around names are ignored
     * @param startDate start date in ISO {@code YYYY-MM-DD} format
     * @param endDate   end date in ISO {@code YYYY-MM-DD} format
     * @param format    {@code json}, {@code csv} or {@code markdown}; anything
     *                  else (including null) yields markdown
     * @return the rendered report, or an error description
     */
    @Tool(
        name = "generate_report",
        description = "Generate a statistical report for specified metrics and time range"
    )
    public String generateReport(
        @ToolParam(name = "metrics", description = "Comma-separated list of metrics to include")
        String metrics,
        @ToolParam(name = "start_date", description = "Start date in YYYY-MM-DD format")
        String startDate,
        @ToolParam(name = "end_date", description = "End date in YYYY-MM-DD format")
        String endDate,
        @ToolParam(name = "format", description = "Output format: json, csv, or markdown", required = false)
        String format
    ) {
        // Arguments come from a model, so validate instead of letting
        // NullPointerException / DateTimeParseException escape.
        if (metrics == null) {
            return "Error: metrics cannot be empty";
        }
        var metricNames = new java.util.ArrayList<String>();
        for (String name : metrics.split(",")) {
            String trimmed = name.trim();
            if (!trimmed.isEmpty()) {
                metricNames.add(trimmed);
            }
        }
        if (metricNames.isEmpty()) {
            return "Error: metrics cannot be empty";
        }
        if (startDate == null || endDate == null) {
            return "Error: start_date and end_date are required";
        }

        LocalDate start;
        LocalDate end;
        try {
            start = LocalDate.parse(startDate.trim());
            end = LocalDate.parse(endDate.trim());
        } catch (java.time.format.DateTimeParseException e) {
            return "Error: dates must use YYYY-MM-DD format: " + e.getMessage();
        }
        if (end.isBefore(start)) {
            return "Error: end_date must not be before start_date";
        }

        String outputFormat = format != null ? format.trim() : "markdown";
        try {
            Report report = dataService.generateReport(metricNames, start, end);
            return switch (outputFormat.toLowerCase(java.util.Locale.ROOT)) {
                case "json" -> report.toJson();
                case "csv" -> report.toCsv();
                default -> report.toMarkdown();
            };
        } catch (Exception e) {
            // Same tool-boundary policy as queryDatabase.
            return "Error generating report: " + e.getMessage();
        }
    }

    // True when the statement contains DROP, DELETE or TRUNCATE as a whole word.
    private boolean containsDangerousKeywords(String sql) {
        return FORBIDDEN_KEYWORDS.matcher(sql).find();
    }

    // Renders the row count followed by the result as a markdown table.
    private String formatQueryResult(QueryResult result) {
        return "Query returned " + result.getRowCount() + " rows:\n\n"
                + result.toMarkdownTable();
    }
}
17.4.3 创建流式输出工具
使用ToolEmitter支持工具的流式输出:
import io.agentscope.core.tool.*;
/**
 * Tools that support streaming output through a {@code ToolEmitter}.
 */
public class StreamingTools {
    // DateTimeFormatter is immutable and thread-safe, so build it once.
    private static final DateTimeFormatter TIME_FORMAT =
            DateTimeFormatter.ofPattern("HH:mm:ss");

    /**
     * Monitors a stock price and streams every price change.
     *
     * <p>Polls once per second until the duration elapses or the thread is
     * interrupted (the interrupt flag is restored before returning).
     *
     * @param symbol          stock symbol, for example AAPL
     * @param durationSeconds monitoring duration in seconds
     * @param emitter         injected by the framework; needs no {@code @ToolParam}
     * @return a summary containing the last observed price
     */
    @Tool(
        name = "monitor_stock",
        description = "Monitor stock price in real-time and report changes"
    )
    public String monitorStock(
        @ToolParam(name = "symbol", description = "Stock symbol (e.g., AAPL)")
        String symbol,
        @ToolParam(name = "duration_seconds", description = "Monitoring duration in seconds")
        int durationSeconds,
        // ToolEmitter is injected by the framework and needs no @ToolParam
        ToolEmitter emitter
    ) {
        StockService stockService = new StockService();
        // 1000L forces long arithmetic; int * int overflows for durations
        // above roughly 24 days and would end the loop immediately.
        long endTime = System.currentTimeMillis() + durationSeconds * 1000L;
        Double lastPrice = null;
        while (System.currentTimeMillis() < endTime) {
            try {
                Double currentPrice = stockService.getPrice(symbol);
                // A missing quote is skipped rather than dereferenced.
                if (currentPrice != null) {
                    if (lastPrice != null && !currentPrice.equals(lastPrice)) {
                        // Push a real-time update through the emitter
                        String change = currentPrice > lastPrice ? "+" : "";
                        double diff = currentPrice - lastPrice;
                        emitter.emit(String.format(
                            "[%s] %s: $%.2f (%s%.2f)\n",
                            LocalTime.now().format(TIME_FORMAT),
                            symbol,
                            currentPrice,
                            change,
                            diff
                        ));
                    }
                    lastPrice = currentPrice;
                }
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                break;
            }
        }
        if (lastPrice == null) {
            // No quote was ever read (zero duration, early interrupt, no data).
            return String.format("Monitoring complete. No price was observed for %s", symbol);
        }
        return String.format("Monitoring complete. Final price of %s: $%.2f",
            symbol, lastPrice);
    }

    /**
     * Generates a long document section by section and streams the progress.
     *
     * @param topic    document topic
     * @param sections number of sections; a value below 1 yields an empty document
     * @param emitter  injected by the framework; receives one progress line per section
     * @return the complete document
     */
    @Tool(
        name = "generate_document",
        description = "Generate a long document with multiple sections"
    )
    public String generateDocument(
        @ToolParam(name = "topic", description = "Document topic")
        String topic,
        @ToolParam(name = "sections", description = "Number of sections to generate")
        int sections,
        ToolEmitter emitter
    ) {
        StringBuilder fullDocument = new StringBuilder();
        for (int i = 1; i <= sections; i++) {
            // Generate the section content
            String sectionContent = generateSection(topic, i);
            fullDocument.append(sectionContent).append("\n\n");
            // Stream the progress
            emitter.emit(String.format("Section %d/%d completed...\n", i, sections));
        }
        return fullDocument.toString();
    }

    // Produces the placeholder content of one section.
    private String generateSection(String topic, int sectionNum) {
        return "## Section " + sectionNum + ": " + topic + "\n\n" +
               "Content for section " + sectionNum + "...";
    }
}
17.4.4 自定义ToolResultConverter
import io.agentscope.core.tool.*;
import io.agentscope.core.message.*;
import java.lang.reflect.Type;
/**
 * Custom result converter that redacts sensitive data (e-mail addresses,
 * card numbers, phone numbers) from the textual form of a tool result.
 */
public class SensitiveDataConverter implements ToolResultConverter {
    // Patterns of the values to redact. Compiled once for the class; Pattern
    // instances are immutable and safe to share between threads.
    // The card-number pattern runs before the phone pattern so that a phone
    // match can never consume part of a longer card number.
    private static final List<Pattern> SENSITIVE_PATTERNS = List.of(
        // Email. The TLD class is [A-Za-z]; the former [A-Z|a-z] also
        // accepted a literal '|' character.
        Pattern.compile("\\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}\\b"),
        // Card number
        Pattern.compile("\\b\\d{4}[- ]?\\d{4}[- ]?\\d{4}[- ]?\\d{4}\\b"),
        // Phone
        Pattern.compile("\\b\\d{3}[-.]?\\d{3}[-.]?\\d{4}\\b")
    );

    /**
     * Converts the result to text via {@code toString()} and redacts it.
     *
     * @param result     tool return value; {@code null} yields the text "No result"
     * @param returnType declared return type of the tool method (unused)
     * @return a result block holding the redacted text
     */
    @Override
    public ToolResultBlock convert(Object result, Type returnType) {
        if (result == null) {
            return ToolResultBlock.of("No result");
        }
        return ToolResultBlock.of(sanitize(result.toString()));
    }

    // Replaces every match of every sensitive pattern with a fixed marker.
    private String sanitize(String input) {
        String result = input;
        for (Pattern pattern : SENSITIVE_PATTERNS) {
            result = pattern.matcher(result).replaceAll("***REDACTED***");
        }
        return result;
    }
}
// Using the custom converter
// NOTE(review): customerService is referenced below but is not declared in
// this snippet; presumably it is an injected field (compare the dataService
// field of DataAnalysisTools) - confirm and declare it.
public class CustomerTools {
@Tool(
name = "get_customer_info",
description = "Get customer information by ID",
converter = SensitiveDataConverter.class // use the custom converter
)
public CustomerInfo getCustomerInfo(
@ToolParam(name = "customer_id") String customerId
) {
return customerService.findById(customerId);
}
}
17.4.5 注册和使用工具
// Create the tool instances
DataAnalysisTools analysisTools = new DataAnalysisTools(dataService);
StreamingTools streamingTools = new StreamingTools();
// Option 1: pass the tool objects directly to the Agent
ReActAgent agent = ReActAgent.builder()
.name("ToolAgent")
.model(model)
.tools(analysisTools, streamingTools) // pass the tool objects
.build();
// Option 2: use a Toolkit container
Toolkit toolkit = Toolkit.builder()
.addTool(analysisTools)
.addTool(streamingTools)
.build();
ReActAgent agent2 = ReActAgent.builder()
.name("ToolAgent2")
.model(model)
.toolkit(toolkit)
.build();
17.5 自定义Formatter
17.5.1 Formatter接口
Formatter负责在AgentScope消息格式和LLM提供商特定格式之间转换:
/**
* Formatter interface.
*
* @param <TReq> provider request message type
* @param <TResp> provider response type
* @param <TParams> provider request parameter builder type
*/
public interface Formatter<TReq, TResp, TParams> {
/**
* Converts AgentScope messages to the provider format.
*/
List<TReq> format(List<Msg> msgs);
/**
* Parses a provider response into a ChatResponse.
*/
ChatResponse parseResponse(TResp response, Instant startTime);
/**
* Applies the generation options.
*/
void applyOptions(
TParams paramsBuilder,
GenerateOptions options,
GenerateOptions defaultOptions
);
/**
* Applies the tool schemas.
*/
void applyTools(TParams paramsBuilder, List<ToolSchema> tools);
/**
* Applies the tool-choice configuration.
*/
default void applyToolChoice(TParams paramsBuilder, ToolChoice toolChoice) {
// Empty by default
}
}
17.5.2 实现自定义Formatter
import io.agentscope.core.formatter.*;
import io.agentscope.core.message.*;
import io.agentscope.core.model.*;
/**
 * Formatter for the custom enterprise LLM: converts AgentScope messages to
 * {@code EnterpriseMessage}s and provider responses back to {@code ChatResponse}.
 */
public class EnterpriseFormatter
        extends AbstractBaseFormatter<EnterpriseMessage, EnterpriseResponse, EnterpriseRequest> {

    /**
     * Converts every message: assistant messages with tool calls carry the
     * calls, tool messages become one provider message per tool result, and
     * everything else becomes a plain role + contents message.
     */
    @Override
    protected List<EnterpriseMessage> doFormat(List<Msg> msgs) {
        List<EnterpriseMessage> formatted = new ArrayList<>();
        for (Msg message : msgs) {
            // Role and text are resolved first for every message
            String providerRole = convertRole(message.getRole());
            String plainText = extractTextContent(message);
            List<EnterpriseContent> parts = toContents(message, plainText);

            // Assistant message that carries tool calls
            if (message.getRole() == MsgRole.ASSISTANT) {
                List<ToolUseBlock> calls = extractToolCalls(message);
                if (!calls.isEmpty()) {
                    formatted.add(EnterpriseMessage.builder()
                            .role(providerRole)
                            .contents(parts)
                            .toolCalls(convertToolCalls(calls))
                            .build());
                    continue;
                }
            }

            // Tool results: one provider message per result block
            if (message.getRole() == MsgRole.TOOL) {
                for (ContentBlock block : message.getContent()) {
                    if (!(block instanceof ToolResultBlock toolResult)) {
                        continue;
                    }
                    String output = convertToolResultToString(toolResult.getOutput());
                    formatted.add(EnterpriseMessage.builder()
                            .role("tool")
                            .toolCallId(toolResult.getToolUseId())
                            .contents(List.of(EnterpriseContent.text(output)))
                            .build());
                }
                continue;
            }

            formatted.add(EnterpriseMessage.builder()
                    .role(providerRole)
                    .contents(parts)
                    .build());
        }
        return formatted;
    }

    // Builds the content list: text and image blocks for multimodal messages,
    // otherwise a single text content holding the extracted text.
    private List<EnterpriseContent> toContents(Msg message, String plainText) {
        List<EnterpriseContent> parts = new ArrayList<>();
        if (!hasMediaContent(message)) {
            parts.add(EnterpriseContent.text(plainText));
            return parts;
        }
        for (ContentBlock block : message.getContent()) {
            if (block instanceof TextBlock textBlock) {
                parts.add(EnterpriseContent.text(textBlock.getText()));
            } else if (block instanceof ImageBlock imageBlock) {
                parts.add(convertImage(imageBlock));
            }
        }
        return parts;
    }

    /**
     * Builds a ChatResponse from the provider response: model, elapsed time
     * since {@code startTime}, content blocks, token usage (when present) and
     * finish reason.
     */
    @Override
    public ChatResponse parseResponse(EnterpriseResponse response, Instant startTime) {
        ChatResponse.Builder chat = ChatResponse.builder()
                .model(response.getModel())
                .duration(Duration.between(startTime, Instant.now()));

        // Content: text and tool-call entries; other kinds are skipped
        List<ContentBlock> parsed = new ArrayList<>();
        for (EnterpriseContent item : response.getContents()) {
            if (item.isText()) {
                parsed.add(new TextBlock(item.getText()));
                continue;
            }
            if (item.isToolCall()) {
                parsed.add(new ToolUseBlock(
                        item.getToolCallId(),
                        item.getToolName(),
                        item.getToolArguments()));
            }
        }
        chat.content(parsed);

        // Token usage
        if (response.getUsage() != null) {
            chat.usage(TokenUsage.builder()
                    .inputTokens(response.getUsage().getPromptTokens())
                    .outputTokens(response.getUsage().getCompletionTokens())
                    .build());
        }

        // Finish reason
        chat.finishReason(mapFinishReason(response.getFinishReason()));
        return chat.build();
    }

    /**
     * Copies temperature, max tokens and stop sequences onto the request;
     * a value is only set when one is resolved.
     */
    @Override
    public void applyOptions(
            EnterpriseRequest request,
            GenerateOptions options,
            GenerateOptions defaultOptions) {
        // Temperature
        Double resolvedTemperature = getOptionOrDefault(
                options, defaultOptions, GenerateOptions::getTemperature);
        if (resolvedTemperature != null) {
            request.setTemperature(resolvedTemperature);
        }

        // Maximum number of tokens
        Integer resolvedMaxTokens = getOptionOrDefault(
                options, defaultOptions, GenerateOptions::getMaxTokens);
        if (resolvedMaxTokens != null) {
            request.setMaxTokens(resolvedMaxTokens);
        }

        // Stop sequences
        List<String> resolvedStops = getOptionOrDefault(
                options, defaultOptions, GenerateOptions::getStopSequences);
        boolean hasStops = resolvedStops != null && !resolvedStops.isEmpty();
        if (hasStops) {
            request.setStop(resolvedStops);
        }
    }

    /** Converts the tool schemas and sets them on the request; no-op for null or empty input. */
    @Override
    public void applyTools(EnterpriseRequest request, List<ToolSchema> tools) {
        if (tools != null && !tools.isEmpty()) {
            List<EnterpriseTool> converted = tools.stream()
                    .map(schema -> convertToolSchema(schema))
                    .collect(Collectors.toList());
            request.setTools(converted);
        }
    }

    // ==================== Helpers ====================

    // Maps the AgentScope role to the provider role name.
    private String convertRole(MsgRole role) {
        return switch (role) {
            case USER -> "user";
            case ASSISTANT -> "assistant";
            case SYSTEM -> "system";
            case TOOL -> "tool";
        };
    }

    // Converts a URL- or Base64-backed image; any other source type is rejected.
    private EnterpriseContent convertImage(ImageBlock ib) {
        Source imageSource = ib.getSource();
        if (imageSource instanceof URLSource byUrl) {
            return EnterpriseContent.imageUrl(byUrl.getUrl());
        }
        if (imageSource instanceof Base64Source inline) {
            return EnterpriseContent.imageBase64(
                    inline.getMediaType(),
                    inline.getData());
        }
        throw new IllegalArgumentException("Unsupported image source type");
    }

    // Collects the tool-use blocks of a message.
    private List<ToolUseBlock> extractToolCalls(Msg msg) {
        return msg.getContent().stream()
                .filter(ToolUseBlock.class::isInstance)
                .map(ToolUseBlock.class::cast)
                .collect(Collectors.toList());
    }

    // Maps the provider finish reason; null and unrecognized values become UNKNOWN.
    private FinishReason mapFinishReason(String reason) {
        if (reason == null) {
            return FinishReason.UNKNOWN;
        }
        return switch (reason.toLowerCase()) {
            case "stop" -> FinishReason.STOP;
            case "tool_calls" -> FinishReason.TOOL_CALLS;
            case "length" -> FinishReason.MAX_TOKENS;
            default -> FinishReason.UNKNOWN;
        };
    }

    // Wraps a tool schema into the provider's "function" tool description.
    private EnterpriseTool convertToolSchema(ToolSchema schema) {
        return EnterpriseTool.builder()
                .type("function")
                .function(
                        EnterpriseFunction.builder()
                                .name(schema.getName())
                                .description(schema.getDescription())
                                .parameters(schema.getParameters())
                                .build())
                .build();
    }
}
17.5.3 实现MultiAgent Formatter
为多智能体场景实现Formatter,支持对话历史合并:
/**
 * Multi-agent formatter.
 *
 * <p>Merges the conversation of several agents into a single context: ordinary
 * messages are accumulated as {@code [name]: text} lines and emitted as one
 * "user" message, while system messages and messages that bypass the history
 * are emitted on their own. Accumulated history is always flushed before such
 * a stand-alone message, so the output keeps the order of the input.
 */
public class EnterpriseMultiAgentFormatter extends EnterpriseFormatter {

    @Override
    protected List<EnterpriseMessage> doFormat(List<Msg> msgs) {
        List<EnterpriseMessage> result = new ArrayList<>();
        StringBuilder historyBuilder = new StringBuilder();

        for (Msg msg : msgs) {
            // Messages that must not be merged into the history
            if (shouldBypassHistory(msg)) {
                flushHistory(historyBuilder, result);
                // Emit this message on its own
                result.addAll(super.doFormat(List.of(msg)));
                continue;
            }

            // System messages are emitted separately; flush first so that a
            // system message in the middle of a conversation is not moved
            // ahead of the history that preceded it.
            if (msg.getRole() == MsgRole.SYSTEM) {
                flushHistory(historyBuilder, result);
                result.add(EnterpriseMessage.builder()
                        .role("system")
                        .contents(List.of(EnterpriseContent.text(
                                extractTextContent(msg))))
                        .build());
                continue;
            }

            // Speaker label: the agent name, or a role label when unnamed
            String agentName = msg.getName();
            if (agentName == null) {
                agentName = formatRoleLabel(msg.getRole());
            }

            // Accumulate the conversation history
            String content = extractTextContent(msg);
            if (!content.isEmpty()) {
                historyBuilder.append(String.format("[%s]: %s\n\n",
                        agentName, content));
            }
        }

        // Emit whatever history is left
        flushHistory(historyBuilder, result);
        return result;
    }

    // Appends the accumulated history (if any) as one message and resets the buffer.
    private void flushHistory(StringBuilder historyBuilder, List<EnterpriseMessage> result) {
        if (historyBuilder.length() > 0) {
            result.add(createHistoryMessage(historyBuilder.toString()));
            historyBuilder.setLength(0);
        }
    }

    // Wraps the trimmed history text into a single "user" message.
    private EnterpriseMessage createHistoryMessage(String history) {
        return EnterpriseMessage.builder()
                .role("user")
                .contents(List.of(EnterpriseContent.text(history.trim())))
                .build();
    }
}
17.6 自定义Hook
17.6.1 Hook接口回顾
/**
* Hook interface.
* Used to monitor and intercept Agent execution.
*/
public interface Hook {
/**
* Handles a hook event.
* Implementations use pattern matching to handle the different event types.
*/
<T extends HookEvent> Mono<T> onEvent(T event);
/**
* Hook priority (a smaller value means a higher priority).
*/
default int priority() {
return 100;
}
}
17.6.2 实现审计日志Hook
import io.agentscope.core.hook.*;
import reactor.core.publisher.Mono;
import java.time.Instant;
/**
 * Audit logging hook.
 *
 * <p>Writes an audit record for agent call start/completion, tool invocations,
 * tool results and errors, for compliance auditing. Events of any other type
 * pass through without a record. The event itself is never modified.
 */
public class AuditLoggingHook implements Hook {
    private final AuditLogger auditLogger;
    private final String applicationId;

    public AuditLoggingHook(AuditLogger auditLogger, String applicationId) {
        this.auditLogger = auditLogger;
        this.applicationId = applicationId;
    }

    @Override
    public int priority() {
        return 50; // Relatively high priority, so the audit trail is complete
    }

    /**
     * Logs an audit record for the event (when its type is audited) and
     * returns the unchanged event.
     */
    @Override
    public <T extends HookEvent> Mono<T> onEvent(T event) {
        AuditRecord record = toAuditRecord(event);
        if (record != null) {
            auditLogger.log(record);
        }
        return Mono.just(event);
    }

    // Builds the audit record for an event, or returns null for event types
    // that are not audited.
    private AuditRecord toAuditRecord(HookEvent event) {
        return switch (event) {
            // Agent call started
            case PreCallEvent e -> AuditRecord.builder()
                    .timestamp(Instant.now())
                    .applicationId(applicationId)
                    .agentName(e.getAgentName())
                    .eventType("AGENT_CALL_START")
                    .inputMessage(summarize(e.getInputMessage()))
                    .build();
            // Tool invocation
            case PreActingEvent e -> {
                ToolUseBlock toolUse = e.getToolUse();
                yield AuditRecord.builder()
                        .timestamp(Instant.now())
                        .applicationId(applicationId)
                        .agentName(e.getAgentName())
                        .eventType("TOOL_INVOCATION")
                        .toolName(toolUse.getName())
                        .toolArguments(sanitizeArguments(toolUse.getArguments()))
                        .build();
            }
            // Tool execution result
            case PostActingEvent e -> AuditRecord.builder()
                    .timestamp(Instant.now())
                    .applicationId(applicationId)
                    .agentName(e.getAgentName())
                    .eventType("TOOL_RESULT")
                    .toolCallId(e.getToolResult().getToolUseId())
                    .resultSummary(summarize(e.getToolResult()))
                    .build();
            // Agent call completed
            case PostCallEvent e -> AuditRecord.builder()
                    .timestamp(Instant.now())
                    .applicationId(applicationId)
                    .agentName(e.getAgentName())
                    .eventType("AGENT_CALL_COMPLETE")
                    .outputSummary(summarize(e.getOutput()))
                    .tokenUsage(extractTokenUsage(e))
                    .build();
            // Error
            case ErrorEvent e -> AuditRecord.builder()
                    .timestamp(Instant.now())
                    .applicationId(applicationId)
                    .agentName(e.getAgentName())
                    .eventType("ERROR")
                    .errorMessage(e.getError().getMessage())
                    .errorType(e.getError().getClass().getSimpleName())
                    .build();
            default -> null;
        };
    }

    // Returns the object's string form, cut to 500 characters plus "...";
    // null stays null.
    private String summarize(Object obj) {
        if (obj == null) {
            return null;
        }
        String str = obj.toString();
        return str.length() > 500 ? str.substring(0, 500) + "..." : str;
    }

    // Returns a copy of the arguments without the sensitive entries.
    // A null argument map yields an empty map, and the keys are matched
    // case-insensitively so that "Password" or "TOKEN" are removed as well.
    private Map<String, Object> sanitizeArguments(Map<String, Object> args) {
        Map<String, Object> sanitized = new HashMap<>();
        if (args == null) {
            return sanitized;
        }
        sanitized.putAll(args);
        sanitized.keySet().removeIf(key -> key != null
                && (key.equalsIgnoreCase("password")
                        || key.equalsIgnoreCase("token")
                        || key.equalsIgnoreCase("secret")));
        return sanitized;
    }

    // Extracts the token usage carried by the event.
    private TokenUsageSummary extractTokenUsage(PostCallEvent e) {
        return TokenUsageSummary.builder()
                .inputTokens(e.getInputTokens())
                .outputTokens(e.getOutputTokens())
                .build();
    }
}
17.6.3 实现速率限制Hook
import io.agentscope.core.hook.*;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Rate limiting hook.
 *
 * <p>Limits how often each agent may be called: every agent name gets its own
 * counter that allows at most {@code maxRequestsPerMinute} calls per
 * one-minute window. Only {@code PreCallEvent}s are counted.
 */
public class RateLimitingHook implements Hook {
    private final int maxRequestsPerMinute;
    // One bucket per agent name
    private final ConcurrentHashMap<String, RateLimitBucket> buckets;

    /**
     * @param maxRequestsPerMinute calls allowed per agent and minute; 0 rejects every call
     * @throws IllegalArgumentException if {@code maxRequestsPerMinute} is negative
     */
    public RateLimitingHook(int maxRequestsPerMinute) {
        if (maxRequestsPerMinute < 0) {
            throw new IllegalArgumentException("maxRequestsPerMinute must be >= 0");
        }
        this.maxRequestsPerMinute = maxRequestsPerMinute;
        this.buckets = new ConcurrentHashMap<>();
    }

    @Override
    public int priority() {
        return 10; // Highest priority, so calls are rejected as early as possible
    }

    /**
     * Consumes one permit for a {@code PreCallEvent}; when none is left, the
     * returned Mono fails with {@link RateLimitExceededException}. All other
     * events pass through unchanged.
     */
    @Override
    public <T extends HookEvent> Mono<T> onEvent(T event) {
        if (event instanceof PreCallEvent e) {
            String agentName = e.getAgentName();
            RateLimitBucket bucket = buckets.computeIfAbsent(
                    agentName,
                    k -> new RateLimitBucket(maxRequestsPerMinute)
            );
            if (!bucket.tryAcquire()) {
                return Mono.error(new RateLimitExceededException(
                        "Rate limit exceeded for agent: " + agentName +
                        ". Max " + maxRequestsPerMinute + " requests per minute."
                ));
            }
        }
        return Mono.just(event);
    }

    /**
     * Fixed-window permit counter: the permits are reset to the maximum once
     * a full minute has passed since the start of the current window.
     * All state is guarded by the bucket's monitor.
     */
    private static class RateLimitBucket {
        private static final long WINDOW_NANOS = 60_000_000_000L; // one minute

        private final int maxTokens;
        private int tokens;
        // System.nanoTime() is monotonic, so wall-clock adjustments can
        // neither stall nor prematurely trigger the refill.
        private long windowStartNanos;

        RateLimitBucket(int maxTokens) {
            this.maxTokens = maxTokens;
            this.tokens = maxTokens;
            this.windowStartNanos = System.nanoTime();
        }

        // Takes one permit; returns false when the current window is exhausted.
        synchronized boolean tryAcquire() {
            refillIfNeeded();
            if (tokens <= 0) {
                return false;
            }
            tokens--;
            return true;
        }

        // Starts a new window (full permits) when the current one has expired.
        private void refillIfNeeded() {
            long now = System.nanoTime();
            if (now - windowStartNanos >= WINDOW_NANOS) {
                tokens = maxTokens;
                windowStartNanos = now;
            }
        }
    }
}
/**
 * Thrown when a caller exceeds the configured rate limit.
 */
public class RateLimitExceededException extends RuntimeException {
    private static final long serialVersionUID = 1L;

    /**
     * @param message description of the exceeded limit
     */
    public RateLimitExceededException(String message) {
        super(message);
    }

    /**
     * @param message description of the exceeded limit
     * @param cause   underlying failure, preserved for diagnostics
     */
    public RateLimitExceededException(String message, Throwable cause) {
        super(message, cause);
    }
}
17.6.4 实现内容过滤Hook
import io.agentscope.core.hook.*;
import io.agentscope.core.message.*;
import reactor.core.publisher.Mono;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* 内容过滤Hook
* 过滤敏感内容和有害输出
*/
public class ContentFilterHook implements Hook {
private final ContentModerator moderator;
private final List<Pattern> blockedPatterns;
private final boolean blockOnViolation;
public ContentFilterHook(
ContentModerator moderator,
List<String> blockedPatterns,
boolean blockOnViolation) {
this.moderator = moderator;
this.blockedPatterns = blockedPatterns.stream()
.map(Pattern::compile)
.collect(Collectors.toList());
this.blockOnViolation = blockOnViolation;
}
@Override
public int priority() {
return 20; // 高优先级
}
@Override
public <T extends HookEvent> Mono<T> onEvent(T event) {
return switch (event) {
case PreReasoningEvent e -> {
// 检查输入内容
List<Msg> messages = e.getInputMessages();
for (Msg msg : messages) {
String content = extractText(msg);
ModerationResult result = moderator.check(content);
if (result.isFlagged()) {
if (blockOnViolation) {
yield Mono.error(new ContentViolationException(
"Input content violates policy: " + result.getReason()));
}
// 记录但不阻止
log.warn("Content flagged: {}", result.getReason());
}
}
yield Mono.just(event);
}
case PostReasoningEvent e -> {
// 检查输出内容
ChatResponse response = e.getResponse();
String outputText = extractOutputText(response);
// 检查正则模式
for (Pattern pattern : blockedPatterns) {
if (pattern.matcher(outputText).find()) {
if (blockOnViolation) {
yield Mono.error(new ContentViolationException(
"Output matches blocked pattern"));
}
// 替换敏感内容
outputText = pattern.matcher(outputText)
.replaceAll("[REDACTED]");
}
}
// 使用内容审核API
ModerationResult result = moderator.check(outputText);
if (result.isFlagged()) {
if (blockOnViolation) {
yield Mono.error(new ContentViolationException(
"Output content violates policy: " + result.getReason()));
}
// 修改输出
e.setResponse(createSafeResponse(response, result));
}
yield Mono.just(event);
}
case PostActingEvent e -> {
// 检查工具输出
ToolResultBlock toolResult = e.getToolResult();
String resultText = extractToolResultText(toolResult);
ModerationResult result = moderator.check(resultText);
if (result.isFlagged() && blockOnViolation) {
yield Mono.error(new ContentViolationException(
"Tool output violates policy"));
}
yield Mono.just(event);
}
default -> Mono.just(event);
};
}
private String extractText(Msg msg) {
return msg.getContent().stream()
.filter(b -> b instanceof TextBlock)
.map(b -> ((TextBlock) b).getText())
.collect(Collectors.joining(" "));
}
private String extractOutputText(ChatResponse response) {
return response.getContent().stream()
.filter(b -> b instanceof TextBlock)
.map(b -> ((TextBlock) b).getText())
.collect(Collectors.joining(" "));
}
private String extractToolResultText(ToolResultBlock result) {
return result.getOutput().stream()
.filter(b -> b instanceof TextBlock)
.map(b -> ((TextBlock) b).getText())
.collect(Collectors.joining(" "));
}
private ChatResponse createSafeResponse(
ChatResponse original,
ModerationResult violation) {
return ChatResponse.builder()
.model(original.getModel())
.content(List.of(new TextBlock(
"I apologize, but I cannot provide that response. " +
"Reason: " + violation.getReason())))
.finishReason(FinishReason.STOP)
.build();
}
}
17.6.5 组合使用多个Hook
// Build each hook on its own
Hook auditHook = new AuditLoggingHook(auditLogger, "APP-001");
Hook rateLimitHook = new RateLimitingHook(60); // 60 calls per minute
List<String> blockedPatterns = List.of("password\\s*=", "api[_-]?key");
Hook contentFilterHook = new ContentFilterHook(moderator, blockedPatterns, true);

// Register all hooks when building the agent
ReActAgent agent = ReActAgent.builder()
        .name("SecureAgent")
        .model(model)
        .hooks(
                rateLimitHook,      // priority 10: rate limit is checked first
                contentFilterHook,  // priority 20: content filtering
                auditHook           // priority 50: audit logging
        )
        .build();
17.7 组件开发最佳实践
17.7.1 设计原则
| 原则 | 说明 | 示例 |
|---|---|---|
| 不变性 | 尽量使用不可变对象 | Msg, ChatResponse等都是不可变的 |
| 防御性复制 | 返回集合时创建副本 | return new ArrayList<>(messages) |
| Null安全 | 使用Optional或明确的null检查 | Objects.requireNonNull() |
| 资源清理 | 实现AutoCloseable | 数据库连接、HTTP客户端 |
| 幂等性 | 重复执行多次与执行一次的效果相同 | 工具方法应该幂等 |
17.7.2 线程安全
// Use thread-safe collections for state shared between threads
private final List<Msg> messages = new CopyOnWriteArrayList<>();
private final Map<String, Object> cache = new ConcurrentHashMap<>();
// Use atomic types for counters and single-reference state
private final AtomicInteger counter = new AtomicInteger(0);
private final AtomicReference<State> state = new AtomicReference<>(State.IDLE);
// Fall back to synchronization when necessary
public synchronized void updateState(State newState) {
// State transition logic goes here
}
// Use volatile to guarantee visibility of the flag across threads
private volatile boolean initialized = false;
17.7.3 异常处理
/**
 * Tool entry point that never lets an exception escape to the caller.
 *
 * @param input value forwarded to {@code doOperation}
 * @return the operation's result, or a human-readable error message on failure
 */
@Tool(name = "safe_operation")
public String safeOperation(
        @ToolParam(name = "input") String input
) {
    String outcome;
    try {
        outcome = doOperation(input);
    } catch (SpecificException failure) {
        // Known failure: surface the user-friendly message.
        outcome = "Operation failed: " + failure.getUserMessage();
    } catch (Exception unexpected) {
        // Unknown failure: log the details, but do not expose internals to the caller.
        log.error("Unexpected error in safeOperation", unexpected);
        outcome = "An unexpected error occurred. Please try again.";
    }
    return outcome;
}
17.7.4 测试策略
// Unit test for the custom Memory implementation
@Test
void testSlidingWindowMemory() {
    SlidingWindowMemory window = new SlidingWindowMemory(3);
    // Four messages go into a window of three; the fourth should make the window slide.
    for (String text : List.of("msg1", "msg2", "msg3", "msg4")) {
        window.addMessage(Msg.user(text));
    }
    List<Msg> retained = window.getMessages();
    assertEquals(3, retained.size());
    assertEquals("msg2", extractText(retained.get(0)));
}
// Unit test for the rate-limiting Hook
@Test
void testRateLimitingHook() {
    int limit = 2; // two calls per minute
    RateLimitingHook limiter = new RateLimitingHook(limit);
    PreCallEvent call = createMockPreCallEvent("TestAgent");
    // Calls within the limit must succeed
    for (int attempt = 0; attempt < limit; attempt++) {
        assertDoesNotThrow(() -> limiter.onEvent(call).block());
    }
    // The next call exceeds the limit and must fail
    assertThrows(RateLimitExceededException.class,
            () -> limiter.onEvent(call).block());
}
17.8 本章小结
本章详细介绍了AgentScope-Java的组件扩展机制:
- Model扩展:通过继承ChatModelBase并实现doStream方法来接入自定义LLM服务
- Memory扩展:实现Memory接口创建自定义记忆策略,如滑动窗口、Token限制等
- Tool开发:使用@Tool和@ToolParam注解定义工具,支持流式输出和自定义结果转换
- Formatter开发:实现Formatter接口适配不同LLM提供商的消息格式
- Hook开发:实现Hook接口创建自定义的执行拦截逻辑,支持审计、限流、过滤等功能
掌握这些扩展能力,你可以根据业务需求灵活定制AgentScope-Java的行为,构建满足特定场景的AI Agent应用。