diff --git a/app/chain/skills.py b/app/chain/skills.py index e5dc346b..beb1f265 100644 --- a/app/chain/skills.py +++ b/app/chain/skills.py @@ -260,9 +260,15 @@ class SkillsChain(ChainBase): request.view = "market" request.market_page = 0 request.awaiting_input = None + elif action == "sources": + request.view = "sources" + request.awaiting_input = None elif action == "search": request.view = "market" request.awaiting_input = "market-search" + elif action == "source-add": + request.view = "sources" + request.awaiting_input = "source-add" elif action == "clear-search": self._clear_market_search(request) elif action == "refresh": @@ -328,6 +334,19 @@ class SkillsChain(ChainBase): if not success: # 保持当前页 pass + elif action == "source-remove" and index: + request.view = "sources" + request.awaiting_input = None + success, message = self._remove_market_source(index) + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) self._render_interaction( request=request, @@ -386,6 +405,59 @@ class SkillsChain(ChainBase): ) return True + add_source = self._extract_market_source_input(normalized) + remove_source_match = re.match( + r"^(?:删除源|移除源|删除仓库|移除仓库|remove source)\s*(\d+)$", + normalized, + re.IGNORECASE, + ) + + if add_source: + request.view = "sources" + request.awaiting_input = None + _, message = self.skillhelper.add_custom_market_source(add_source) + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if remove_source_match: + request.view = "sources" + request.awaiting_input = None + _, message = self._remove_market_source( + page_index=int(remove_source_match.group(1)) + ) + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + if lowered in {"刷新", "refresh"}: request.awaiting_input = None self._render_interaction( @@ -433,6 +505,8 @@ class SkillsChain(ChainBase): request.view = "installed" elif lowered in {"2", "市场", "market"}: request.view = "market" + elif lowered in {"3", "技能源", "源", "sources", "source"}: + request.view = "sources" elif self._extract_market_search_query(normalized): self._apply_market_search( request, @@ -445,7 +519,7 @@ class SkillsChain(ChainBase): source=source, userid=userid, username=username, - title="请输入 1 查看已安装技能,2 查看技能市场,或回复 刷新/退出", + title="请输入 1 查看已安装技能,2 查看技能市场,3 管理技能源,或回复 刷新/退出", ) ) return True @@ -458,6 +532,30 @@ class SkillsChain(ChainBase): ) return True + if request.awaiting_input == "source-add": + if lowered in {"取消", "cancel"}: + request.awaiting_input = None + else: + _, message = self.skillhelper.add_custom_market_source(normalized) + request.awaiting_input = None + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + if lowered in {"清除搜索", "取消搜索", "clear", "clear search"}: if request.view == "market" or request.market_query: self._clear_market_search(request) @@ -524,7 +622,7 @@ class SkillsChain(ChainBase): return True if request.view == "installed" and remove_match: request.awaiting_input = None - success, message = self._remove_local_skill( + _, message = self._remove_local_skill( request=request, page_index=int(remove_match.group(1)), ) @@ -598,6 +696,18 @@ class SkillsChain(ChainBase): return False, f"技能 {target.id} 是内置技能,不能删除" return self.skillhelper.remove_local_skill(target.id) + def _remove_market_source(self, page_index: int) -> Tuple[bool, str]: + """ + 按当前源列表序号删除自定义技能源,避免误删内置默认源。 + """ + sources = self.skillhelper.list_market_source_entries() + if page_index < 1 or page_index > len(sources): + return False, "删除源序号无效" + target = sources[page_index - 1] + if not target.removable: + return False, f"技能源 {target.label} 是内置默认源,不能删除" + return self.skillhelper.remove_custom_market_source(target.source) + def _render_interaction( self, request: PendingSkillsInteraction, @@ -622,6 +732,11 @@ class SkillsChain(ChainBase): request=request, force_market_refresh=force_market_refresh, ) + elif request.view == "sources": + title, text, buttons = self._build_sources_view( + request=request, + force_market_refresh=force_market_refresh, + ) else: title, text, buttons = self._build_root_view( request=request, @@ -654,22 +769,26 @@ class SkillsChain(ChainBase): for skill in self.skillhelper.list_market_skills(force=force_market_refresh) if not skill.installed ] - sources = self.skillhelper.get_market_sources() + source_entries = self.skillhelper.list_market_source_entries() source_lines = [] - for index, source in enumerate(sources, start=1): - source_lines.append(f"{index}. {self.skillhelper.describe_market_source(source)}") + for index, source_entry in enumerate(source_entries, start=1): + state = "内置" if source_entry.builtin else "自定义" + source_lines.append( + f"{index}. {source_entry.label}({state})" + ) text_lines = [ f"已安装技能:{len(local_skills)}", f"市场可安装技能:{len(market_skills)}", ] if source_lines: - text_lines.extend(["", "公开技能源:", *source_lines]) + text_lines.extend(["", "当前技能源:", *source_lines]) text_lines.extend( [ "", "1. 查看已安装技能", "2. 浏览技能市场", + "3. 管理技能源", "回复 刷新 重新获取市场数据,回复 退出 结束交互", ] ) @@ -679,6 +798,7 @@ class SkillsChain(ChainBase): buttons = [ [{"text": "已安装技能", "callback_data": f"skills:{request.request_id}:installed"}], [{"text": "技能市场", "callback_data": f"skills:{request.request_id}:market"}], + [{"text": "技能源管理", "callback_data": f"skills:{request.request_id}:sources"}], [ {"text": "刷新市场", "callback_data": f"skills:{request.request_id}:refresh"}, {"text": "关闭", "callback_data": f"skills:{request.request_id}:close"}, @@ -848,6 +968,79 @@ class SkillsChain(ChainBase): ) return "技能市场", "\n".join(text_lines), buttons + def _build_sources_view( + self, + request: PendingSkillsInteraction, + force_market_refresh: bool = False, # noqa: ARG002 + ) -> Tuple[str, str, Optional[List[List[dict]]]]: + """ + 构建技能源管理视图,提供自定义 GitHub 源的增删入口。 + """ + sources = self.skillhelper.list_market_source_entries() + custom_count = len([source for source in sources if not source.builtin]) + text_lines = [ + f"当前技能源:{len(sources)}", + f"自定义技能源:{custom_count}", + ] + if request.awaiting_input == "source-add": + text_lines.extend( + [ + "", + "添加输入中:直接回复 GitHub 仓库地址即可。", + "支持 owner/repo、https://github.com/owner/repo,或 /tree// 形式。", + "回复 取消 结束输入。", + ] + ) + + if not sources: + text_lines.extend(["", "当前没有可用技能源"]) + else: + for index, market_source in enumerate(sources, start=1): + state = "自定义可删" if market_source.removable else "内置默认" + text_lines.extend( + [ + "", + f"{index}. {market_source.label}({state})", + self._truncate(market_source.source, limit=200), + ] + ) + + text_lines.extend( + [ + "", + "回复 添加源 添加自定义源,回复 删除源 <序号> 删除自定义源,回复 返回 回到菜单,回复 退出 结束交互", + ] + ) + + buttons = None + if self._supports_interactive_buttons(request.channel): + buttons = [ + [ + { + "text": "添加自定义源", + "callback_data": f"skills:{request.request_id}:source-add", + } + ] + ] + for index, market_source in enumerate(sources, start=1): + if not market_source.removable: + continue + buttons.append( + [ + { + "text": f"删除 {index}", + "callback_data": f"skills:{request.request_id}:source-remove:{index}", + } + ] + ) + buttons.append( + [ + {"text": "返回", "callback_data": f"skills:{request.request_id}:root"}, + {"text": "关闭", "callback_data": f"skills:{request.request_id}:close"}, + ] + ) + return "技能源管理", "\n".join(text_lines), buttons + @staticmethod def _truncate(text: str, limit: int = 140) -> str: """ @@ -975,7 +1168,9 @@ class SkillsChain(ChainBase): return "请输入 搜索 <关键词>、清除搜索、安装 <序号>、刷新、n、p、返回 或 退出" if view == "installed": return "请输入 删除 <序号>、n、p、返回 或 退出" - return "请输入 1、2、搜索 <关键词>、刷新 或 退出" + if view == "sources": + return "请输入 添加源 、删除源 <序号>、返回 或 退出" + return "请输入 1、2、3、搜索 <关键词>、刷新 或 退出" def _get_market_skills( self, @@ -1008,6 +1203,21 @@ class SkillsChain(ChainBase): match = re.match(r"^(?:搜索|查找|查)\s+(.+)$", normalized) return match.group(1).strip() if match else "" + @staticmethod + def _extract_market_source_input(text: str) -> str: + """ + 从文本命令中提取自定义技能源地址。 + """ + normalized = (text or "").strip() + if not normalized: + return "" + match = re.match( + r"^(?:添加源|新增源|添加仓库|新增仓库|add source)\s+(.+)$", + normalized, + re.IGNORECASE, + ) + return match.group(1).strip() if match else "" + @staticmethod def _apply_market_search( request: PendingSkillsInteraction, diff --git a/app/helper/skill.py b/app/helper/skill.py index 4879e691..1751e846 100644 --- a/app/helper/skill.py +++ b/app/helper/skill.py @@ -53,6 +53,14 @@ class SkillInfo: removable: bool = False +@dataclass +class SkillMarketSource: + source: str + label: str + builtin: bool = True + removable: bool = False + + class SkillHelper(metaclass=WeakSingleton): """ 技能市场与本地技能管理 @@ -81,6 +89,16 @@ class SkillHelper(metaclass=WeakSingleton): return [] return [item.strip() for item in settings.SKILL_MARKET.split(",") if item.strip()] + @staticmethod + def get_default_market_sources() -> List[str]: + """ + 返回系统默认的技能市场列表,用于区分内置源和用户追加源。 + """ + default_value = type(settings).model_fields["SKILL_MARKET"].default + if not default_value: + return [] + return [item.strip() for item in str(default_value).split(",") if item.strip()] + @staticmethod def _ensure_user_skills_dir() -> Path: """ @@ -90,6 +108,31 @@ class SkillHelper(metaclass=WeakSingleton): skill_dir.mkdir(parents=True, exist_ok=True) return skill_dir + @staticmethod + def _canonical_market_source(source: str) -> Optional[str]: + """ + 生成市场源的规范化值,用于去重、默认源比对和持久化。 + """ + normalized = (source or "").strip() + if not normalized: + return None + + registry = SkillHelper._parse_market_registry(normalized) + if registry: + return registry["registry_url"].rstrip("/") + + repo = SkillHelper._parse_market_repo(normalized) + if repo: + # 对 GitHub 仓库保留分支和技能根目录,避免不同路径的技能仓库混淆。 + if repo["branch"]: + return ( + f"{repo['repo_url']}/tree/" + f"{repo['branch']}/{repo['root_path'].strip('/')}" + ).rstrip("/") + return repo["repo_url"].rstrip("/") + + return normalized.rstrip("/") + @staticmethod def _build_repo_source_label(repo_name: Optional[str]) -> str: """ @@ -125,6 +168,98 @@ class SkillHelper(metaclass=WeakSingleton): return self._build_repo_source_label(repo.get("repo_name")) return source + def list_market_source_entries(self) -> List[SkillMarketSource]: + """ + 返回当前技能源及其是否属于内置默认源的展示信息。 + """ + default_keys = { + self._canonical_market_source(item) for item in self.get_default_market_sources() + } + results: List[SkillMarketSource] = [] + for source in self.get_market_sources(): + source_key = self._canonical_market_source(source) + builtin = source_key in default_keys + results.append( + SkillMarketSource( + source=source, + label=self.describe_market_source(source), + builtin=builtin, + removable=not builtin, + ) + ) + return results + + @staticmethod + def _persist_market_sources(sources: List[str]) -> Tuple[bool, str]: + """ + 将技能源列表写回配置文件,并同步更新内存中的 settings。 + """ + filtered_sources = [item.strip() for item in sources if item and item.strip()] + success, message = settings.update_setting( + key="SKILL_MARKET", + value=",".join(filtered_sources), + ) + if success is False: + return False, message + return True, message + + def add_custom_market_source(self, source: str) -> Tuple[bool, str]: + """ + 添加自定义 GitHub 技能源,支持 owner/repo 与 GitHub URL 两种写法。 + """ + repo = self._parse_market_repo(source) + if not repo: + return ( + False, + "仅支持 GitHub skills 仓库,示例:openai/skills 或 https://github.com/openai/skills", + ) + + canonical_source = self._canonical_market_source(source) + if not canonical_source: + return False, "技能源地址不能为空" + + existing_keys = { + self._canonical_market_source(item) for item in self.get_market_sources() + } + if canonical_source in existing_keys: + return False, "该技能源已存在" + + current_sources = self.get_market_sources() + success, message = self._persist_market_sources( + current_sources + [canonical_source] + ) + if not success: + return False, message + return True, f"已添加技能源:{self.describe_market_source(canonical_source)}" + + def remove_custom_market_source(self, source: str) -> Tuple[bool, str]: + """ + 删除一个自定义技能源,内置默认源不允许移除。 + """ + canonical_source = self._canonical_market_source(source) + if not canonical_source: + return False, "技能源地址无效" + + default_keys = { + self._canonical_market_source(item) for item in self.get_default_market_sources() + } + if canonical_source in default_keys: + return False, f"技能源 {self.describe_market_source(source)} 是内置默认源,不能删除" + + current_sources = self.get_market_sources() + remaining_sources = [ + item + for item in current_sources + if self._canonical_market_source(item) != canonical_source + ] + if len(remaining_sources) == len(current_sources): + return False, "技能源不存在" + + success, message = self._persist_market_sources(remaining_sources) + if not success: + return False, message + return True, f"已删除技能源:{self.describe_market_source(source)}" + @staticmethod def _normalize_repo_url(repo_url: str) -> Optional[str]: """ diff --git a/tests/test_skills_command.py b/tests/test_skills_command.py index ba4f8766..f3b8f410 100644 --- a/tests/test_skills_command.py +++ b/tests/test_skills_command.py @@ -12,10 +12,16 @@ setattr(sys.modules["qbittorrentapi"], "TorrentFilesList", list) sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc")) setattr(sys.modules["transmission_rpc"], "File", object) sys.modules.setdefault("psutil", ModuleType("psutil")) +sys.modules.setdefault("aioshutil", ModuleType("aioshutil")) from app.chain.message import MessageChain from app.chain.skills import SkillsChain, skills_interaction_manager -from app.helper.skill import SkillHelper, SkillInfo +from app.helper.skill import ( + SkillHelper, + SkillInfo, + SkillMarketSource, + settings as skill_settings, +) from app.schemas.types import MessageChannel @@ -289,6 +295,70 @@ class TestSkillsCommand(unittest.TestCase): self.assertEqual(local_skills[0].registry_name, "ClawHub") self.assertEqual(local_skills[0].source_label, "社区注册表 · ClawHub") + def test_skillhelper_lists_market_sources_and_marks_custom_entries(self): + helper = SkillHelper() + + with patch.object( + helper, + "get_market_sources", + return_value=[ + "https://clawhub.ai", + "https://github.com/openai/skills", + "https://github.com/acme/custom-skills", + ], + ), patch.object( + helper, + "get_default_market_sources", + return_value=[ + "https://clawhub.ai", + "https://github.com/openai/skills", + ], + ): + sources = helper.list_market_source_entries() + + self.assertEqual(len(sources), 3) + self.assertTrue(sources[0].builtin) + self.assertTrue(sources[1].builtin) + self.assertFalse(sources[2].builtin) + self.assertTrue(sources[2].removable) + self.assertEqual(sources[2].label, "仓库来源 · acme/custom-skills") + + def test_skillhelper_add_custom_market_source_updates_setting(self): + helper = SkillHelper() + + with patch.object( + helper, + "get_market_sources", + return_value=["https://github.com/openai/skills"], + ), patch.object( + type(skill_settings), + "update_setting", + return_value=(True, ""), + ) as update_setting: + success, message = helper.add_custom_market_source("acme/custom-skills") + + self.assertTrue(success) + self.assertIn("acme/custom-skills", message) + update_setting.assert_called_once_with( + key="SKILL_MARKET", + value="https://github.com/openai/skills,https://github.com/acme/custom-skills", + ) + + def test_skillhelper_remove_custom_market_source_blocks_builtin(self): + helper = SkillHelper() + + with patch.object( + helper, + "get_default_market_sources", + return_value=["https://github.com/openai/skills"], + ): + success, message = helper.remove_custom_market_source( + "https://github.com/openai/skills" + ) + + self.assertFalse(success) + self.assertIn("内置默认源", message) + def test_skills_chain_market_view_marks_clawhub_as_community_source(self): chain = SkillsChain() request = skills_interaction_manager.create_or_replace( @@ -379,14 +449,35 @@ class TestSkillsCommand(unittest.TestCase): chain.skillhelper, "list_market_skills", return_value=[] ), patch.object( chain.skillhelper, - "get_market_sources", - return_value=["https://clawhub.ai", "https://github.com/openai/skills"], + "list_market_source_entries", + return_value=[ + SkillMarketSource( + source="https://clawhub.ai", + label="社区注册表 · ClawHub", + builtin=True, + removable=False, + ), + SkillMarketSource( + source="https://github.com/openai/skills", + label="官方仓库 · openai/skills", + builtin=True, + removable=False, + ), + SkillMarketSource( + source="https://github.com/acme/custom-skills", + label="仓库来源 · acme/custom-skills", + builtin=False, + removable=True, + ), + ], ): title, text, _buttons = chain._build_root_view(request=request) self.assertEqual(title, "技能管理") self.assertIn("社区注册表 · ClawHub", text) self.assertIn("官方仓库 · openai/skills", text) + self.assertIn("仓库来源 · acme/custom-skills", text) + self.assertIn("3. 管理技能源", text) def test_skills_chain_callback_enters_search_input_mode(self): chain = SkillsChain() @@ -461,6 +552,131 @@ class TestSkillsCommand(unittest.TestCase): self.assertIsNone(request.awaiting_input) render.assert_called_once() + def test_skills_chain_callback_enters_source_add_mode(self): + chain = SkillsChain() + request = skills_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + ) + + with patch.object(chain, "_render_interaction") as render: + handled = chain.handle_callback_interaction( + callback_data=f"skills:{request.request_id}:source-add", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + ) + + self.assertTrue(handled) + self.assertEqual(request.view, "sources") + self.assertEqual(request.awaiting_input, "source-add") + render.assert_called_once() + + def test_skills_chain_followup_text_adds_custom_market_source(self): + chain = SkillsChain() + request = skills_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + ) + request.view = "sources" + request.awaiting_input = "source-add" + + with patch.object( + chain.skillhelper, + "add_custom_market_source", + return_value=(True, "已添加技能源:仓库来源 · acme/custom-skills"), + ) as add_source, patch.object(chain, "_render_interaction") as render, patch.object( + chain, "post_message" + ) as post_message: + handled = chain.handle_text_interaction( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="acme/custom-skills", + ) + + self.assertTrue(handled) + self.assertIsNone(request.awaiting_input) + add_source.assert_called_once_with("acme/custom-skills") + post_message.assert_called_once() + render.assert_called_once() + + def test_skills_chain_text_removes_custom_market_source_by_index(self): + chain = SkillsChain() + request = skills_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + ) + + with patch.object( + chain, + "_remove_market_source", + return_value=(True, "已删除技能源:仓库来源 · acme/custom-skills"), + ) as remove_source, patch.object(chain, "_render_interaction") as render, patch.object( + chain, "post_message" + ) as post_message: + handled = chain.handle_text_interaction( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="删除源 3", + ) + + self.assertTrue(handled) + self.assertEqual(request.view, "sources") + remove_source.assert_called_once_with(page_index=3) + post_message.assert_called_once() + render.assert_called_once() + + def test_skills_chain_source_view_lists_custom_sources(self): + chain = SkillsChain() + request = skills_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + ) + request.view = "sources" + + with patch.object( + chain.skillhelper, + "list_market_source_entries", + return_value=[ + SkillMarketSource( + source="https://clawhub.ai", + label="社区注册表 · ClawHub", + builtin=True, + removable=False, + ), + SkillMarketSource( + source="https://github.com/acme/custom-skills", + label="仓库来源 · acme/custom-skills", + builtin=False, + removable=True, + ), + ], + ): + title, text, buttons = chain._build_sources_view(request=request) + + self.assertEqual(title, "技能源管理") + self.assertIn("社区注册表 · ClawHub", text) + self.assertIn("仓库来源 · acme/custom-skills", text) + self.assertIn("删除自定义源", text) + self.assertTrue(buttons) + self.assertEqual( + buttons[1][0]["callback_data"], + f"skills:{request.request_id}:source-remove:2", + ) + def test_skills_chain_updates_buttons_via_edit_message(self): chain = SkillsChain() buttons = [[{"text": "安装 1", "callback_data": "skills:req:install:1"}]]