导读:理论讲再多不如动手做一次。本文将带你从零开始构建一个实用的学术研究助手Agent,包含完整的代码实现和最佳实践。


系列文章导航

  1. AI智能体开发(一):从概念到架构设计
  2. AI智能体开发(二):技术栈选择与工具集成
  3. AI智能体开发(三):实战构建研究助手Agent
  4. AI智能体开发(四):进阶技巧与性能优化

项目概述

我们将构建一个学术研究助手Agent,它能够:

  • 智能搜索 - 根据主题搜索arXiv上的相关论文
  • 自动阅读 - 下载并解析PDF论文内容
  • 提取关键信息 - 识别研究方法、实验结果、结论
  • 生成研究报告 - 输出结构化的Markdown格式报告
  • 保存结果 - 将报告保存为文件,方便后续查阅

技术栈

  • 框架:LangChain
  • LLM:GPT-4 API
  • 向量库:Chroma(本地)
  • 工具:arxiv、pypdf、markdown

环境准备

Step 1: 创建项目

1
2
3
4
5
6
7
8
# 创建项目目录
mkdir research-agent
cd research-agent

# 创建虚拟环境
python -m venv venv
source venv/bin/activate # Linux/Mac
# venv\Scripts\activate # Windows

Step 2: 安装依赖

创建 requirements.txt

1
2
3
4
5
6
7
8
langchain>=0.1.0
langchain-openai>=0.0.5
chromadb>=0.4.0
arxiv>=2.0.0
pypdf>=3.17.0
markdown>=3.5.0
python-dotenv>=1.0.0
tenacity>=8.2.0

安装依赖:

1
pip install -r requirements.txt

Step 3: 配置环境变量

创建 .env 文件:

1
OPENAI_API_KEY=sk-your-api-key-here

创建 config.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import os
from dotenv import load_dotenv

load_dotenv()

class Config:
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# LLM配置
LLM_MODEL = "gpt-4"
LLM_TEMPERATURE = 0.7
LLM_MAX_TOKENS = 4000

# 搜索配置
MAX_SEARCH_RESULTS = 5
MAX_PAPER_LENGTH = 3000 # 每篇论文最大读取字符数

# 输出配置
OUTPUT_DIR = "./reports"

config = Config()

# 确保输出目录存在
os.makedirs(config.OUTPUT_DIR, exist_ok=True)

核心组件实现

论文搜索工具

创建 tools/paper_search.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import arxiv
from typing import List, Dict
from config import config

def search_papers(query: str, max_results: int = None) -> List[Dict]:
"""
搜索arXiv学术论文

Args:
query: 搜索关键词
max_results: 最大返回结果数

Returns:
论文列表,每个论文包含标题、作者、摘要、链接等信息
"""
if max_results is None:
max_results = config.MAX_SEARCH_RESULTS

print(f"正在搜索论文: {query}")

# 创建搜索对象
search = arxiv.Search(
query=query,
max_results=max_results,
sort_by=arxiv.SortCriterion.Relevance,
sort_order=arxiv.SortOrder.Descending
)

papers = []
for result in search.results():
paper = {
'title': result.title,
'authors': [author.name for author in result.authors],
'summary': result.summary,
'published': result.published.strftime('%Y-%m-%d'),
'pdf_url': result.pdf_url,
'entry_id': result.entry_id,
'categories': result.categories
}
papers.append(paper)

print(f"- 找到 {len(papers)} 篇相关论文")
return papers


def format_paper_info(paper: Dict) -> str:
"""格式化单篇论文信息"""
authors_str = ", ".join(paper['authors'][:3]) # 只显示前3个作者
if len(paper['authors']) > 3:
authors_str += " et al."

return f"""
标题: {paper['title']}
作者: {authors_str}
发表日期: {paper['published']}
摘要: {paper['summary'][:500]}...
链接: {paper['entry_id']}
分类: {', '.join(paper['categories'][:3])}
""".strip()


def format_search_results(papers: List[Dict]) -> str:
"""格式化搜索结果"""
if not papers:
return "未找到相关论文"

formatted = []
for i, paper in enumerate(papers, 1):
formatted.append(f"\n{'='*60}\n论文 {i}:\n{format_paper_info(paper)}")

return "\n".join(formatted)

PDF阅读器

创建 tools/pdf_reader.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import requests
from pypdf import PdfReader
from io import BytesIO
from config import config

def download_and_read_pdf(pdf_url: str, max_length: int = None) -> str:
"""
下载并读取PDF文件内容

Args:
pdf_url: PDF文件的URL
max_length: 最大读取字符数

Returns:
PDF文本内容
"""
if max_length is None:
max_length = config.MAX_PAPER_LENGTH

try:
print(f"正在下载PDF: {pdf_url[:80]}...")

# 下载PDF
response = requests.get(pdf_url, timeout=30)
response.raise_for_status()

# 读取PDF内容
pdf_file = BytesIO(response.content)
reader = PdfReader(pdf_file)

# 提取文本
text_parts = []
for page in reader.pages:
text = page.extract_text()
if text:
text_parts.append(text)

full_text = "\n".join(text_parts)

# 限制长度
if len(full_text) > max_length:
full_text = full_text[:max_length] + "\n...(内容过长,已截断)"

print(f"- 成功读取PDF,共 {len(full_text)} 字符")
return full_text

except Exception as e:
print(f"- 读取PDF失败: {str(e)}")
return f"无法读取PDF内容: {str(e)}"


def extract_sections(text: str) -> Dict[str, str]:
"""
从论文文本中提取主要章节

Args:
text: 论文全文

Returns:
包含各章节内容的字典
"""
sections = {}

# 常见的论文章节标题
section_patterns = [
'abstract', 'introduction', 'background', 'related work',
'methodology', 'methods', 'approach', 'experiment',
'results', 'evaluation', 'discussion', 'conclusion',
'future work', 'references'
]

lines = text.split('\n')
current_section = None
current_content = []

for line in lines:
# 检测是否是章节标题
line_lower = line.lower().strip()
is_section = False

for pattern in section_patterns:
if pattern in line_lower and len(line) < 100:
# 保存之前的章节
if current_section:
sections[current_section] = '\n'.join(current_content)

current_section = pattern
current_content = []
is_section = True
break

if not is_section and current_section:
current_content.append(line)

# 保存最后一个章节
if current_section:
sections[current_section] = '\n'.join(current_content)

return sections

总结生成器

创建 utils/summarizer.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from langchain_openai import ChatOpenAI
from config import config

class PaperSummarizer:
"""论文总结器"""

def __init__(self):
self.llm = ChatOpenAI(
model=config.LLM_MODEL,
temperature=config.LLM_TEMPERATURE,
max_tokens=config.LLM_MAX_TOKENS
)

def summarize_abstract(self, abstract: str) -> str:
"""总结论文摘要"""
prompt = f"""
请用简洁的语言总结以下论文摘要,提取核心要点:

摘要内容:
{abstract}

要求:
1. 用2-3句话概括研究目标
2. 列出主要研究方法
3. 指出关键发现或贡献
4. 总长度控制在200字以内
"""

response = self.llm.invoke(prompt)
return response.content

def summarize_full_paper(self, paper_text: str, title: str) -> str:
"""总结完整论文"""
# 分段处理长文本
sections = self._split_into_chunks(paper_text, chunk_size=2000)

summaries = []
for i, chunk in enumerate(sections):
print(f"正在总结第 {i+1}/{len(sections)} 部分...")

prompt = f"""
你是专业的学术研究员。请总结以下论文章节的内容。

论文标题:{title}

章节内容:
{chunk}

要求:
1. 提取该章节的核心观点
2. 记录重要的实验数据或结果
3. 保持学术性和准确性
4. 长度控制在300-500字
"""

response = self.llm.invoke(prompt)
summaries.append(response.content)

# 合并所有部分的总结
combined_summary = "\n\n".join(summaries)

# 生成最终总结
final_prompt = f"""
请基于以下各章节的总结,生成一份完整的论文总结报告。

论文标题:{title}

各章节总结:
{combined_summary}

请按照以下结构组织报告:

## 研究背景与目标
(简要介绍研究领域和研究问题)

## 研究方法
(详细说明采用的研究方法和技术路线)

## 主要发现
(列出关键的实验结果和发现,使用要点形式)

## 创新点与贡献
(说明本研究的创新之处和对领域的贡献)

## 局限性与未来工作
(指出研究的局限性和未来可能的研究方向)

要求:
- 语言专业但易懂
- 重点突出,逻辑清晰
- 总长度控制在1500-2000字
"""

final_response = self.llm.invoke(final_prompt)
return final_response.content

def _split_into_chunks(self, text: str, chunk_size: int) -> list:
"""将长文本分割成小块"""
chunks = []
for i in range(0, len(text), chunk_size):
chunks.append(text[i:i+chunk_size])
return chunks

def extract_key_points(self, paper_text: str) -> list:
"""提取论文的关键要点"""
prompt = f"""
请从以下论文内容中提取5-8个关键要点。

论文内容:
{paper_text[:3000]}

要求:
1. 每个要点用一句话概括
2. 涵盖研究目标、方法、结果、结论
3. 使用简洁明了的语言
4. 以列表形式返回
"""

response = self.llm.invoke(prompt)

# 解析列表
lines = response.content.strip().split('\n')
key_points = [line.strip('-•* ') for line in lines if line.strip()]

return key_points[:8] # 最多返回8个要点

报告生成器

创建 utils/report_generator.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from datetime import datetime
from typing import List, Dict
from config import config
import os

class ReportGenerator:
"""研究报告生成器"""

def __init__(self):
self.output_dir = config.OUTPUT_DIR

def generate_research_report(
self,
topic: str,
papers: List[Dict],
summaries: List[str],
key_findings: List[str]
) -> str:
"""
生成结构化研究报告

Args:
topic: 研究主题
papers: 论文列表
summaries: 论文总结列表
key_findings: 关键发现

Returns:
Markdown格式的研究报告
"""
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

report = f"""# {topic} 研究报告

**生成时间**: {timestamp}
**研究论文数量**: {len(papers)}

---

## 📋 目录

1. [研究概述](#研究概述)
2. [文献综述](#文献综述)
3. [关键发现](#关键发现)
4. [研究方法分析](#研究方法分析)
5. [趋势与展望](#趋势与展望)
6. [参考文献](#参考文献)

---

## 研究概述

本次研究围绕 **"{topic}"** 这一主题展开,通过检索和分析最新的学术论文,旨在全面了解该领域的研究现状、主要方法和未来发展方向。

**研究范围**:
- 时间范围:最近3年的研究成果
- 数据来源:arXiv学术数据库
- 分析论文:{len(papers)} 篇高质量论文

**研究目标**:
1. 梳理该领域的主要研究方向
2. 总结常用的研究方法和技术
3. 识别当前的研究热点和趋势
4. 预测未来的发展方向

---

## 文献综述

"""

# 添加每篇论文的总结
for i, (paper, summary) in enumerate(zip(papers, summaries), 1):
authors_str = ", ".join(paper['authors'][:3])
if len(paper['authors']) > 3:
authors_str += " et al."

report += f"""### {i}. {paper['title']}

**作者**: {authors_str}
**发表日期**: {paper['published']}
**PDF**: {paper['pdf_url']}

**总结**:
{summary}

---

"""

# 添加关键发现
report += """## 关键发现

基于对上述论文的分析,我们总结出以下关键发现:

"""

for i, finding in enumerate(key_findings, 1):
report += f"**{i}.** {finding}\n\n"

# 添加研究方法分析
report += """## 研究方法分析

通过对这些论文的研究方法进行分析,我们发现以下几种方法在该领域应用广泛:

"""

report += self._analyze_methods(papers, summaries)

# 添加趋势与展望
report += """## 趋势与展望

基于当前研究现状,我们对该领域的未来发展做出以下预测:

"""

report += self._generate_future_outlook(topic, summaries)

# 添加参考文献
report += """## 参考文献

"""

for i, paper in enumerate(papers, 1):
authors_str = ", ".join(paper['authors'][:3])
if len(paper['authors']) > 3:
authors_str += " et al."

report += f"""{i}. {authors_str}. "{paper['title']}". {paper['published']}. {paper['pdf_url']}

"""

report += f"""---

**报告结束**

*本报告由AI研究助手自动生成,仅供参考。建议结合原始论文进行深入阅读。*
"""

return report

def _analyze_methods(self, papers: List[Dict], summaries: List[str]) -> str:
"""分析方法论"""
# 这里可以添加更复杂的分析逻辑
# 简化版本:直接返回通用描述
return """1. **实证研究方法** - 通过实验验证理论假设
2. **文献综述法** - 系统梳理现有研究成果
3. **案例分析法** - 深入分析典型案例
4. **定量分析法** - 使用统计方法分析数据
5. **定性分析法** - 通过访谈、观察等方式收集数据

具体到每篇论文,建议阅读原文获取详细的方法论描述。
"""

def _generate_future_outlook(self, topic: str, summaries: List[str]) -> str:
"""生成未来展望"""
# 这里可以使用LLM生成更智能的展望
return f"""1. **技术深化** - 现有方法将继续优化和完善
2. **跨学科融合** - 与其他领域的结合将产生新的研究方向
3. **应用拓展** - 研究成果将在更多实际场景中得到应用
4. **标准化进程** - 行业标准和规范将逐步建立
5. **开源生态** - 开源工具和平台将促进研究协作

对于"{topic}"这一主题,建议持续关注顶级会议和期刊的最新发表。
"""

def save_report(self, report: str, filename: str = None) -> str:
"""
保存报告到文件

Args:
report: 报告内容
filename: 文件名(可选)

Returns:
保存的文件路径
"""
if filename is None:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f"research_report_{timestamp}.md"

filepath = os.path.join(self.output_dir, filename)

with open(filepath, 'w', encoding='utf-8') as f:
f.write(report)

print(f"报告已保存至: {filepath}")
return filepath

Agent核心实现

创建 agents/research_agent.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from langchain_openai import ChatOpenAI
from langchain.agents import create_react_agent, AgentExecutor
from langchain.tools import Tool
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

from tools.paper_search import search_papers, format_search_results
from tools.pdf_reader import download_and_read_pdf
from utils.summarizer import PaperSummarizer
from utils.report_generator import ReportGenerator
from config import config


class ResearchAgent:
"""学术研究助手Agent"""

def __init__(self):
# 初始化LLM
self.llm = ChatOpenAI(
model=config.LLM_MODEL,
temperature=config.LLM_TEMPERATURE,
max_tokens=config.LLM_MAX_TOKENS
)

# 初始化工具
self.summarizer = PaperSummarizer()
self.report_generator = ReportGenerator()

# 定义工具
self.tools = self._create_tools()

# 创建Prompt模板
self.prompt = self._create_prompt()

# 创建Agent
self.agent = create_react_agent(self.llm, self.tools, self.prompt)

# 创建执行器
self.memory = ConversationBufferMemory(
memory_key="chat_history",
return_messages=True
)

self.executor = AgentExecutor(
agent=self.agent,
tools=self.tools,
memory=self.memory,
verbose=True,
max_iterations=15,
max_execution_time=600 # 10分钟超时
)

# 存储研究结果
self.research_data = {
'topic': None,
'papers': [],
'summaries': [],
'key_findings': []
}

def _create_tools(self) -> list:
"""创建Agent可用的工具"""

def search_and_summarize(query: str) -> str:
"""搜索论文并生成总结"""
papers = search_papers(query)
if not papers:
return "未找到相关论文"

# 保存论文信息
self.research_data['topic'] = query
self.research_data['papers'] = papers

# 总结第一篇论文(简化版,实际可以总结多篇)
first_paper = papers[0]
summary = self.summarizer.summarize_abstract(first_paper['summary'])
self.research_data['summaries'].append(summary)

result = f"""找到 {len(papers)} 篇相关论文。

第一篇论文总结:
{summary}

完整论文列表:
{format_search_results(papers)}
"""
return result

def read_full_paper(url: str) -> str:
"""阅读完整论文"""
text = download_and_read_pdf(url)
if "无法读取" in text:
return text

# 提取关键要点
key_points = self.summarizer.extract_key_points(text)
self.research_data['key_findings'].extend(key_points)

return f"""论文内容已读取({len(text)} 字符)。

关键要点:
{chr(10).join(['- ' + kp for kp in key_points[:5]])}
"""

def generate_final_report(topic: str) -> str:
"""生成最终研究报告"""
if not self.research_data['papers']:
return "请先进行论文搜索"

report = self.report_generator.generate_research_report(
topic=topic or self.research_data['topic'],
papers=self.research_data['papers'],
summaries=self.research_data['summaries'],
key_findings=self.research_data['key_findings'][:10]
)

# 保存报告
filepath = self.report_generator.save_report(report)

return f"研究报告已生成并保存至: {filepath}\n\n报告预览:\n{report[:500]}..."

tools = [
Tool(
name="search_papers",
func=search_and_summarize,
description="搜索arXiv学术论文并生成总结。输入:研究主题或关键词"
),
Tool(
name="read_paper",
func=read_full_paper,
description="下载并阅读完整PDF论文。输入:论文的PDF URL"
),
Tool(
name="generate_report",
func=generate_final_report,
description="基于已收集的论文生成研究报告。输入:研究主题"
)
]

return tools

def _create_prompt(self) -> PromptTemplate:
"""创建Agent的Prompt模板"""
template = """你是一个专业的学术研究助手。你的任务是帮助用户研究特定主题并生成高质量的研究报告。

你可以使用以下工具:
{tools}

工作流程:
1. 首先使用 search_papers 搜索相关论文
2. 如果需要深入了解,使用 read_paper 阅读完整论文
3. 最后使用 generate_report 生成研究报告

使用工具的格式:
Thought: 思考下一步该做什么
Action: 工具名称
Action Input: 工具输入
Observation: 工具返回结果
... (可以重复Thought/Action/Observation多次)
Thought: 我现在知道最终答案
Final Answer: 给用户的最终回答

开始!

问题:{input}
{agent_scratchpad}
"""

return PromptTemplate.from_template(template)

def research(self, topic: str) -> str:
"""
执行研究任务

Args:
topic: 研究主题

Returns:
研究结果
"""
print(f"\n🚀 开始研究主题: {topic}\n")

query = f"""请帮我研究以下主题:{topic}

要求:
1. 搜索相关的学术论文
2. 阅读并总结至少1篇重要论文
3. 生成一份结构化的研究报告

请开始执行。"""

result = self.executor.invoke({"input": query})

print("\n- 研究完成!\n")
return result["output"]

运行Agent

创建 main.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from agents.research_agent import ResearchAgent

def main():
"""主函数"""
print("="*60)
print("🎓 AI学术研究助手")
print("="*60)

# 创建Agent
agent = ResearchAgent()

# 获取用户输入
topic = input("\n请输入研究主题: ").strip()

if not topic:
topic = "Transformer模型在自然语言处理中的最新进展"
print(f"使用默认主题: {topic}")

# 执行研究
try:
result = agent.research(topic)
print("\n" + "="*60)
print("研究结果:")
print("="*60)
print(result)

except Exception as e:
print(f"\n- 研究过程中出现错误: {str(e)}")
import traceback
traceback.print_exc()


if __name__ == "__main__":
main()

运行:

1
python main.py

运行示例

输入

1
请输入研究主题: Transformer模型在自然语言处理中的最新进展

预期输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
============================================================
🎓 AI学术研究助手
============================================================

🚀 开始研究主题: Transformer模型在自然语言处理中的最新进展

> Entering new AgentExecutor chain...

Thought: 我需要先搜索关于Transformer模型的最新论文
Action: search_papers
Action Input: Transformer natural language processing 2025 2026

正在搜索论文: Transformer natural language processing 2025 2026
- 找到 5 篇相关论文

Observation: 找到 5 篇相关论文。

第一篇论文总结:
该论文研究了Transformer架构在NLP任务中的最新优化方法...

完整论文列表:
[论文列表...]

Thought: 我需要阅读第一篇论文的完整内容以获取更多信息
Action: read_paper
Action Input: https://arxiv.org/pdf/xxxx.xxxxx.pdf

正在下载PDF...
- 成功读取PDF,共 15000 字符

Observation: 论文内容已读取(15000 字符)。

关键要点:
- 提出了新的注意力机制优化方法
- 在多个基准测试中取得SOTA结果
- ...

Thought: 现在我有足够的信息来生成研究报告
Action: generate_report
Action Input: Transformer模型在自然语言处理中的最新进展

报告已保存至: ./reports/research_report_20260518_100000.md

Observation: 研究报告已生成并保存至: ./reports/research_report_20260518_100000.md

Thought: 我现在知道最终答案
Final Answer: 我已完成了对"Transformer模型在自然语言处理中的最新进展"的研究...

> Finished chain.

- 研究完成!

============================================================
研究结果:
============================================================
[完整的研究报告内容...]

测试与优化

单元测试

创建 tests/test_research_agent.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import unittest
from tools.paper_search import search_papers, format_paper_info
from tools.pdf_reader import download_and_read_pdf
from utils.summarizer import PaperSummarizer

class TestPaperSearch(unittest.TestCase):
"""测试论文搜索功能"""

def test_search_papers(self):
"""测试搜索论文"""
papers = search_papers("machine learning", max_results=3)
self.assertIsInstance(papers, list)
self.assertLessEqual(len(papers), 3)

if papers:
self.assertIn('title', papers[0])
self.assertIn('authors', papers[0])
self.assertIn('summary', papers[0])

def test_format_paper_info(self):
"""测试格式化论文信息"""
paper = {
'title': 'Test Paper',
'authors': ['Author1', 'Author2'],
'summary': 'This is a test summary.',
'published': '2026-05-18',
'pdf_url': 'https://example.com/paper.pdf',
'entry_id': 'https://arxiv.org/abs/xxxx.xxxxx',
'categories': ['cs.AI']
}

formatted = format_paper_info(paper)
self.assertIn('Test Paper', formatted)
self.assertIn('Author1', formatted)

class TestPaperSummarizer(unittest.TestCase):
"""测试论文总结器"""

def setUp(self):
self.summarizer = PaperSummarizer()

def test_extract_key_points(self):
"""测试提取关键要点"""
text = """
This paper presents a novel approach to machine learning.
We propose a new algorithm that achieves state-of-the-art results.
Our experiments show significant improvements over baseline methods.
The key contributions are: 1) New architecture design, 2) Efficient training method.
"""

key_points = self.summarizer.extract_key_points(text)
self.assertIsInstance(key_points, list)
self.assertGreater(len(key_points), 0)

if __name__ == '__main__':
unittest.main()

运行测试:

1
python -m pytest tests/ -v

最佳实践

错误处理

1
2
3
4
5
6
7
8
9
from tenacity import retry, stop_after_attempt, wait_exponential

@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=2, max=10)
)
def robust_api_call(func, *args, **kwargs):
"""带重试机制的API调用"""
return func(*args, **kwargs)

性能优化预告:上面的重试机制虽然简单有效,但在高并发场景下可能会导致资源浪费。第四篇会介绍更高级的错误处理和熔断机制,以及如何实现优雅降级。

成本控制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class TokenTracker:
"""Token使用追踪器"""

def __init__(self, budget: int = 100000):
self.budget = budget
self.used = 0

def track(self, response):
"""追踪Token使用"""
tokens = response.usage_metadata['total_tokens']
self.used += tokens

if self.used > self.budget * 0.8:
print(f"Token使用已达{self.used}/{self.budget} (80%)")

if self.used > self.budget:
raise Exception(f"Token预算超支: {self.used}/{self.budget}")

缓存机制

1
2
3
4
5
6
7
8
9
10
11
from functools import lru_cache
import hashlib

@lru_cache(maxsize=100)
def cached_search(query: str):
"""缓存搜索结果"""
return search_papers(query)

def get_cache_key(query: str) -> str:
"""生成缓存键"""
return hashlib.md5(query.encode()).hexdigest()

本篇总结

本文从零开始构建了一个完整的学术研究助手Agent,实现了:

核心功能

  • 论文搜索工具(arXiv集成)
  • PDF阅读器(自动下载和解析)
  • 智能总结器(分段处理长文本)
  • 报告生成器(结构化Markdown输出)
  • ReAct Agent(自主决策和执行)

工程实践

  • 模块化设计(清晰的代码结构)
  • 错误处理和重试机制
  • 成本控制和监控
  • 单元测试覆盖
  • 缓存优化