190 lines
7.1 KiB
Python
190 lines
7.1 KiB
Python
import json
import logging
from datetime import datetime, date

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager

from ..models.news import RawNews, ProcessedNews, LLMConfig, NewsSource, SystemLog
from ..crawler.rss_fetcher import fetch_rss
from .llm_client import LLMClient
|
||
|
||
# Module-level logger; used for warnings that are not persisted as SystemLog rows.
logger = logging.getLogger(__name__)

# System prompt sent with every analysis request (persona: senior pharma-industry analyst).
SYSTEM_PROMPT = "你是医药行业资深分析师,擅长解读全球医药政策、临床研究、行业动态。"

# Per-article user prompt, filled via str.format(title=..., content=..., language=...).
# Literal JSON braces are doubled ({{ }}) so str.format emits them as single braces.
# The model is told to reply with a bare JSON object, pick one of four fixed categories,
# and score importance on a 1-10 scale.
ANALYSIS_PROMPT = """分析以下新闻,返回严格的 JSON 格式结果,不要包含任何其他文字。

新闻标题:{title}
新闻内容:{content}
新闻语言:{language}

返回格式:
{{
"is_medical_related": true,
"title_zh": "中文标题(英文原文请翻译成简洁中文)",
"summary": "中文摘要(100-150字,客观陈述核心内容)",
"opinion": "核心观点或行业影响(50-100字,分析性语言,点明实际意义)",
"keywords": ["关键词1", "关键词2", "关键词3", "关键词4", "关键词5"],
"importance_score": 8.5,
"importance_reason": "评分理由(30字内)",
"category": "药品监管"
}}

category 只能是以下四个之一:药品监管 / 临床研究 / 行业动态 / 政策法规

importance_score 评分标准(1-10):
9-10:重大监管决定 / 突破性研究 / 影响整个行业的政策
7-8 :行业重要动态,有明显商业或学术价值
5-6 :常规行业新闻,有一定参考价值
1-4 :普通资讯,信息价值有限
"""
|
||
|
||
|
||
async def _log(db: AsyncSession, level: str, event_type: str, message: str):
    """Record one SystemLog row and commit right away.

    The commit covers the whole session transaction, so any other pending
    changes on `db` are committed together with the log entry.
    """
    entry = SystemLog(level=level, event_type=event_type, message=message)
    db.add(entry)
    await db.commit()
|
||
|
||
|
||
async def _get_active_llm(db: AsyncSession) -> LLMConfig | None:
    """Return an LLMConfig row flagged active, or None if there is none.

    The query is limited to one row with no ORDER BY, so when several configs
    are active, which one is returned is up to the database.
    """
    stmt = (
        select(LLMConfig)
        .where(LLMConfig.is_active == True)  # noqa: E712 - SQL expression, not a Python truth test
        .limit(1)
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
|
||
|
||
|
||
def _extract_json_object(text: str) -> dict | None:
    """Return the JSON object contained in an LLM reply, or None if there is no usable one.

    Accepts a bare object, a Markdown-fenced block (```json / ```JSON / plain ```),
    or an object wrapped in extra prose. Non-object JSON (list, number, ...) is rejected
    because callers read fields with ``.get``.
    """
    text = text.strip()
    if text.startswith("```"):
        # "```json\n{...}\n```" -> "json\n{...}\n" -> "{...}"
        body = text.split("```")[1].strip()
        if body[:4].lower() == "json":
            body = body[4:]
        text = body.strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span, e.g. when the model adds a preamble.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end < start:
            return None
        try:
            parsed = json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            return None
    return parsed if isinstance(parsed, dict) else None


async def _analyze_article(client: LLMClient, title: str, content: str, language: str) -> dict | None:
    """Ask the LLM to analyse one article and return its parsed JSON verdict.

    Args:
        client: configured LLM client; ``complete(system_prompt, prompt)`` is awaited.
        title: article title.
        content: article body; only the first 2000 characters are sent.
        language: source language code; "zh" is described to the model as Chinese,
            anything else as English.

    Returns:
        The decoded JSON object (a dict), or None when the request failed, the reply
        was empty, or it contained no JSON object. These failures are logged as warnings
        and reported via None rather than raised.
    """
    prompt = ANALYSIS_PROMPT.format(
        title=title,
        content=content[:2000] if content else "(无正文)",
        language="中文" if language == "zh" else "英文",
    )
    try:
        reply = await client.complete(SYSTEM_PROMPT, prompt)
    except Exception as e:
        # Best-effort per article: one failed call must not abort the whole batch.
        logger.warning("LLM request failed: %s", e)
        return None

    if not isinstance(reply, str) or not reply.strip():
        logger.warning("LLM parse error: empty or non-text response")
        return None

    analysis = _extract_json_object(reply)
    if analysis is None:
        logger.warning("LLM parse error: no JSON object in response: %.200s", reply)
    return analysis
|
||
|
||
|
||
async def _select_top_10(db: AsyncSession, target: date):
    """Reset the featured flags of news processed on `target` and elect a TOP 10.

    Selection guarantees one slot to the highest-scoring item of each of the four
    categories (when that category has any item that day). Remaining slots are filled
    by importance_score. Ranks (1 = best) are then assigned by score, so a category's
    guaranteed pick never outranks a higher-scoring item.

    Returns:
        Number of items featured (at most 10). Commits the session.
    """
    result = await db.execute(
        select(ProcessedNews)
        .where(func.date(ProcessedNews.processed_at) == target)
        .order_by(ProcessedNews.importance_score.desc())
    )
    all_news = result.scalars().all()

    # Reset: clear flags left by any earlier election for the same day.
    for n in all_news:
        n.is_featured = False
        n.featured_rank = None

    categories = ["药品监管", "临床研究", "行业动态", "政策法规"]
    selected: list[ProcessedNews] = []
    picked: set[int] = set()  # id() of selected objects, for O(1) membership checks

    # First pass: best item of each category (all_news is already ordered by score).
    for cat in categories:
        best = next((n for n in all_news if n.category == cat), None)
        if best is not None:
            selected.append(best)
            picked.add(id(best))

    # Second pass: fill the remaining slots with the best items not yet chosen.
    for n in all_news:
        if len(selected) >= 10:
            break
        if id(n) not in picked:
            selected.append(n)
            picked.add(id(n))

    # Rank by score rather than by selection order. list.sort is stable, so equal
    # scores keep their selection order; a missing score ranks as 0.
    selected.sort(key=lambda n: n.importance_score or 0.0, reverse=True)

    for rank, n in enumerate(selected, start=1):
        n.is_featured = True
        n.featured_rank = rank

    await db.commit()
    return len(selected)
|
||
|
||
|
||
# Categories the analysis prompt allows; anything else returned by the model maps to the default.
_VALID_CATEGORIES = ("药品监管", "临床研究", "行业动态", "政策法规")
_DEFAULT_CATEGORY = "行业动态"
_DEFAULT_SCORE = 5.0


def _as_flag(value) -> bool:
    """Interpret an LLM boolean field; a JSON string such as "false" must not count as true."""
    if isinstance(value, str):
        return value.strip().lower() in ("true", "yes", "1", "是")
    return bool(value)


def _coerce_score(value) -> float:
    """Return importance_score as a float clamped to the prompt's 1-10 scale, or the default if unusable."""
    try:
        score = float(value)
    except (TypeError, ValueError):
        return _DEFAULT_SCORE
    if score != score:  # NaN, e.g. the model returned "nan"
        return _DEFAULT_SCORE
    return min(max(score, 1.0), 10.0)


def _build_processed(raw: RawNews, analysis: dict) -> ProcessedNews:
    """Map one LLM analysis onto a ProcessedNews row, tolerating missing, null or odd-typed fields."""
    source = raw.source
    category = analysis.get("category")
    keywords = analysis.get("keywords")
    return ProcessedNews(
        raw_news_id=raw.id,
        title_zh=analysis.get("title_zh") or raw.title,
        summary=analysis.get("summary") or "",
        opinion=analysis.get("opinion"),
        keywords=keywords if isinstance(keywords, list) else [],
        importance_score=_coerce_score(analysis.get("importance_score")),
        importance_reason=analysis.get("importance_reason"),
        category=category if category in _VALID_CATEGORIES else _DEFAULT_CATEGORY,
        source_name=source.name if source else "",
        source_url=raw.url,
        published_at=raw.published_at,
    )


async def _crawl_sources(db: AsyncSession) -> int:
    """Fetch every active NewsSource and insert unseen items as RawNews; return the number added.

    A source whose fetch raises is logged (SystemLog "crawl_error") and skipped.
    Each source's inserts are committed before the next source is fetched.
    """
    sources_result = await db.execute(select(NewsSource).where(NewsSource.is_active == True))  # noqa: E712
    # Read the needed columns up front. If the session expires objects on commit,
    # the per-source commits below would otherwise force an implicit reload mid-loop.
    targets = [(src.id, src.url) for src in sources_result.scalars().all()]

    raw_added = 0
    known_urls: set[str] = set()  # URLs already stored, or queued earlier in this run
    for source_id, source_url in targets:
        try:
            items = await fetch_rss(source_url)
        except Exception as e:
            logger.warning("RSS fetch failed for %s: %s", source_url, e)
            await _log(db, "ERROR", "crawl_error", f"抓取失败:{source_url}({e})")
            continue

        # One IN query per source instead of one SELECT per item.
        candidate_urls = list({item["url"] for item in items} - known_urls)
        if candidate_urls:
            existing = await db.execute(select(RawNews.url).where(RawNews.url.in_(candidate_urls)))
            known_urls.update(existing.scalars().all())

        for item in items:
            if item["url"] in known_urls:
                continue
            # Also drops an item repeated within this feed or across feeds in this run.
            known_urls.add(item["url"])
            db.add(RawNews(
                source_id=source_id,
                title=item["title"],
                url=item["url"],
                raw_content=item["content"],
                published_at=item["published_at"],
            ))
            raw_added += 1
        await db.commit()
    return raw_added


async def _process_pending(db: AsyncSession, client: LLMClient, limit: int) -> tuple[int, int]:
    """Run LLM analysis on up to `limit` pending RawNews rows, lowest id first.

    Medical-related articles become ProcessedNews rows and are marked "processed".
    The rest, including failed analyses, are marked "skipped". Commits once at the end.

    Returns:
        (processed_count, skipped_count)
    """
    pending_result = await db.execute(
        select(RawNews)
        # Outer join so rows whose source is missing are still handled (language falls back to "zh").
        .outerjoin(RawNews.source)
        # Populate raw.source from this join. A lazy load here would need implicit IO,
        # which AsyncSession does not allow.
        .options(contains_eager(RawNews.source))
        .where(RawNews.status == "pending")
        .order_by(RawNews.id)
        .limit(limit)
    )
    pending = pending_result.scalars().all()

    processed_count = skipped_count = 0
    for raw in pending:
        language = raw.source.language if raw.source else "zh"
        analysis = await _analyze_article(client, raw.title, raw.raw_content or "", language)
        if not analysis or not _as_flag(analysis.get("is_medical_related")):
            raw.status = "skipped"
            skipped_count += 1
            continue
        db.add(_build_processed(raw, analysis))
        raw.status = "processed"
        processed_count += 1

    await db.commit()
    return processed_count, skipped_count


async def run_daily_pipeline(db: AsyncSession, *, max_process: int = 120):
    """Run the daily crawl -> LLM analysis -> TOP 10 selection pipeline.

    Progress and failures are recorded as SystemLog rows via ``_log``. The pipeline
    stops early (after logging an error) when no active LLM config exists.

    Args:
        db: session used for every step; it is committed several times along the way.
        max_process: maximum number of pending RawNews rows analysed in this run.
    """
    await _log(db, "INFO", "pipeline_start", "每日流水线启动")

    llm_cfg = await _get_active_llm(db)
    if not llm_cfg:
        await _log(db, "ERROR", "pipeline_error", "未找到激活的 LLM 配置,请在管理后台配置")
        return

    client = LLMClient(
        provider=llm_cfg.provider,
        api_key=llm_cfg.api_key,
        base_url=llm_cfg.base_url,
        model=llm_cfg.model_name,
    )

    # ── 1. Crawl ─────────────────────────────────────────────────────────────
    raw_added = await _crawl_sources(db)
    await _log(db, "INFO", "crawl_done", f"抓取完成,新增 {raw_added} 条原始新闻")

    # ── 2. AI processing ─────────────────────────────────────────────────────
    processed_count, skipped_count = await _process_pending(db, client, max_process)
    await _log(db, "INFO", "process_done", f"AI 处理完成:{processed_count} 条入库,{skipped_count} 条跳过")

    # ── 3. Select TOP 10 ─────────────────────────────────────────────────────
    # NOTE(review): date.today() is the server's local date; confirm processed_at is
    # stored in the same timezone, or the day boundary used by _select_top_10 will drift.
    featured = await _select_top_10(db, date.today())
    await _log(db, "INFO", "pipeline_done", f"流水线完成,精选 {featured} 条入今日 TOP 10")
|