diff --git a/app/api/servcookie.py b/app/api/servcookie.py index ffae4147..30242b12 100644 --- a/app/api/servcookie.py +++ b/app/api/servcookie.py @@ -1,8 +1,6 @@ import gzip import json -from hashlib import md5 -from typing import Annotated, Callable -from typing import Any, Dict, Optional +from typing import Annotated, Callable, Any, Dict, Optional from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response from fastapi.responses import PlainTextResponse @@ -11,7 +9,7 @@ from fastapi.routing import APIRoute from app import schemas from app.core.config import settings from app.log import logger -from app.utils.common import decrypt +from app.utils.crypto import CryptoJsUtils, HashUtils class GzipRequest(Request): @@ -47,7 +45,7 @@ async def verify_server_enabled(): cookie_router = APIRouter(route_class=GzipRoute, - tags=['servcookie'], + tags=["servcookie"], dependencies=[Depends(verify_server_enabled)]) @@ -100,15 +98,14 @@ def get_decrypted_cookie_data(uuid: str, password: str, """ 加载本地加密数据并解密为Cookie """ - key_md5 = md5() - key_md5.update((uuid + '-' + password).encode('utf-8')) - aes_key = (key_md5.hexdigest()[:16]).encode('utf-8') + combined_string = f"{uuid}-{password}" + aes_key = HashUtils.md5(combined_string)[:16].encode("utf-8") if encrypted: try: - decrypted_data = decrypt(encrypted, aes_key).decode('utf-8') + decrypted_data = CryptoJsUtils.decrypt(encrypted, aes_key).decode("utf-8") decrypted_data = json.loads(decrypted_data) - if 'cookie_data' in decrypted_data: + if "cookie_data" in decrypted_data: return decrypted_data except Exception as e: logger.error(f"解密Cookie数据失败:{str(e)}") diff --git a/app/helper/cookiecloud.py b/app/helper/cookiecloud.py index 4fbabd81..0be8be5e 100644 --- a/app/helper/cookiecloud.py +++ b/app/helper/cookiecloud.py @@ -1,24 +1,27 @@ import json -from hashlib import md5 from typing import Any, Dict, Tuple, Optional from app.core.config import settings -from app.utils.common import decrypt +from app.utils.crypto import CryptoJsUtils, HashUtils from app.utils.http import RequestUtils from app.utils.string import StringUtils +from app.utils.url import UrlUtils class CookieCloudHelper: _ignore_cookies: list = ["CookieAutoDeleteBrowsingDataCleanup", "CookieAutoDeleteCleaningDiscarded"] def __init__(self): - self._sync_setting() + self.__sync_setting() self._req = RequestUtils(content_type="application/json") - def _sync_setting(self): - self._server = settings.COOKIECLOUD_HOST - self._key = settings.COOKIECLOUD_KEY - self._password = settings.COOKIECLOUD_PASSWORD + def __sync_setting(self): + """ + 同步CookieCloud配置项 + """ + self._server = UrlUtils.standardize_base_url(settings.COOKIECLOUD_HOST) + self._key = StringUtils.safe_strip(settings.COOKIECLOUD_KEY) + self._password = StringUtils.safe_strip(settings.COOKIECLOUD_PASSWORD) self._enable_local = settings.COOKIECLOUD_ENABLE_LOCAL self._local_path = settings.COOKIE_PATH @@ -28,7 +31,7 @@ class CookieCloudHelper: :return: Cookie数据、错误信息 """ # 更新为最新设置 - self._sync_setting() + self.__sync_setting() if ((not self._server and not self._enable_local) or not self._key @@ -37,11 +40,11 @@ class CookieCloudHelper: if self._enable_local: # 开启本地服务时,从本地直接读取数据 - result = self._load_local_encrypt_data(self._key) + result = self.__load_local_encrypt_data(self._key) if not result: return {}, "未从本地CookieCloud服务加载到cookie数据,请检查服务器设置、用户KEY及加密密码是否正确" else: - req_url = "%s/get/%s" % (self._server, str(self._key).strip()) + req_url = UrlUtils.combine_url(host=self._server, path=f"get/{self._key}") ret = self._req.get_res(url=req_url) if ret and ret.status_code == 200: try: @@ -59,9 +62,9 @@ class CookieCloudHelper: if not encrypted: return {}, "未获取到cookie密文" else: - crypt_key = self._get_crypt_key() + crypt_key = self.__get_crypt_key() try: - decrypted_data = decrypt(encrypted, crypt_key).decode('utf-8') + decrypted_data = CryptoJsUtils.decrypt(encrypted, crypt_key).decode("utf-8") result = json.loads(decrypted_data) except Exception as e: return {}, "cookie解密失败:" + str(e) @@ -105,15 +108,17 @@ class CookieCloudHelper: ret_cookies[domain] = cookie_str return ret_cookies, "" - def _get_crypt_key(self) -> bytes: + def __get_crypt_key(self) -> bytes: """ 使用UUID和密码生成CookieCloud的加解密密钥 """ - md5_generator = md5() - md5_generator.update((str(self._key).strip() + '-' + str(self._password).strip()).encode('utf-8')) - return (md5_generator.hexdigest()[:16]).encode('utf-8') + combined_string = f"{self._key}-{self._password}" + return HashUtils.md5(combined_string)[:16].encode("utf-8") - def _load_local_encrypt_data(self, uuid: str) -> Dict[str, Any]: + def __load_local_encrypt_data(self, uuid: str) -> Dict[str, Any]: + """ + 获取本地CookieCloud数据 + """ file_path = self._local_path / f"{uuid}.json" # 检查文件是否存在 if not file_path.exists(): diff --git a/app/modules/emby/emby.py b/app/modules/emby/emby.py index 24c4f896..c95cd79f 100644 --- a/app/modules/emby/emby.py +++ b/app/modules/emby/emby.py @@ -11,10 +11,10 @@ from app.core.config import settings from app.log import logger from app.schemas.types import MediaType from app.utils.http import RequestUtils +from app.utils.url import UrlUtils class Emby: - _host: str = None _playhost: str = None _apikey: str = None @@ -26,10 +26,10 @@ class Emby: return self._host = host if self._host: - self._host = RequestUtils.standardize_base_url(self._host) + self._host = UrlUtils.standardize_base_url(self._host) self._playhost = play_host if self._playhost: - self._playhost = RequestUtils.standardize_base_url(self._playhost) + self._playhost = UrlUtils.standardize_base_url(self._playhost) self._apikey = apikey self.user = self.get_user(settings.SUPERUSER) self.folders = self.get_emby_folders() diff --git a/app/modules/jellyfin/jellyfin.py b/app/modules/jellyfin/jellyfin.py index 6fecde46..709c5ee8 100644 --- a/app/modules/jellyfin/jellyfin.py +++ b/app/modules/jellyfin/jellyfin.py @@ -8,10 +8,10 @@ from app.core.config import settings from app.log import logger from app.schemas import MediaType from app.utils.http import RequestUtils +from app.utils.url import UrlUtils class Jellyfin: - _host: str = None _apikey: str = None _playhost: str = None @@ -23,10 +23,10 @@ class Jellyfin: return self._host = host if self._host: - self._host = RequestUtils.standardize_base_url(self._host) + self._host = UrlUtils.standardize_base_url(self._host) self._playhost = play_host if self._playhost: - self._playhost = RequestUtils.standardize_base_url(self._playhost) + self._playhost = UrlUtils.standardize_base_url(self._playhost) self._apikey = apikey self.user = self.get_user(settings.SUPERUSER) self.serverid = self.get_server_id() diff --git a/app/modules/plex/plex.py b/app/modules/plex/plex.py index e261db3f..f7c2346c 100644 --- a/app/modules/plex/plex.py +++ b/app/modules/plex/plex.py @@ -13,6 +13,7 @@ from app.core.config import settings from app.log import logger from app.schemas import MediaType from app.utils.http import RequestUtils +from app.utils.url import UrlUtils class Plex: @@ -25,10 +26,10 @@ class Plex: return self._host = host if self._host: - self._host = RequestUtils.standardize_base_url(self._host) + self._host = UrlUtils.standardize_base_url(self._host) self._playhost = play_host if self._playhost: - self._playhost = RequestUtils.standardize_base_url(self._playhost) + self._playhost = UrlUtils.standardize_base_url(self._playhost) self._token = token if self._host and self._token: try: diff --git a/app/utils/common.py b/app/utils/common.py index 5980b3b7..4ba1aa7e 100644 --- a/app/utils/common.py +++ b/app/utils/common.py @@ -1,11 +1,6 @@ -import base64 import time -from hashlib import md5 from typing import Any -from Crypto import Random -from Crypto.Cipher import AES - from app.schemas import ImmediateException @@ -41,48 +36,3 @@ def retry(ExceptionToCheck: Any, return f_retry return deco_retry - - -def bytes_to_key(data: bytes, salt: bytes, output=48) -> bytes: - # extended from https://gist.github.com/gsakkis/4546068 - assert len(salt) == 8, len(salt) - data += salt - key = md5(data).digest() - final_key = key - while len(final_key) < output: - key = md5(key + data).digest() - final_key += key - return final_key[:output] - - -def encrypt(message: bytes, passphrase: bytes) -> bytes: - """ - CryptoJS 加密原文 - - This is a modified copy of https://stackoverflow.com/questions/36762098/how-to-decrypt-password-from-javascript-cryptojs-aes-encryptpassword-passphras - """ - salt = Random.new().read(8) - key_iv = bytes_to_key(passphrase, salt, 32 + 16) - key = key_iv[:32] - iv = key_iv[32:] - aes = AES.new(key, AES.MODE_CBC, iv) - length = 16 - (len(message) % 16) - data = message + (chr(length) * length).encode() - return base64.b64encode(b"Salted__" + salt + aes.encrypt(data)) - - -def decrypt(encrypted: str | bytes, passphrase: bytes) -> bytes: - """ - CryptoJS 解密密文 - - 来源同encrypt - """ - encrypted = base64.b64decode(encrypted) - assert encrypted[0:8] == b"Salted__" - salt = encrypted[8:16] - key_iv = bytes_to_key(passphrase, salt, 32 + 16) - key = key_iv[:32] - iv = key_iv[32:] - aes = AES.new(key, AES.MODE_CBC, iv) - data = aes.decrypt(encrypted[16:]) - return data[:-(data[-1] if type(data[-1]) == int else ord(data[-1]))] diff --git a/app/utils/crypto.py b/app/utils/crypto.py index 2bc9b99f..b6ad5690 100644 --- a/app/utils/crypto.py +++ b/app/utils/crypto.py @@ -1,22 +1,27 @@ import base64 +import hashlib +from hashlib import md5 +from typing import Union +from Crypto import Random +from Crypto.Cipher import AES from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization, hashes -from cryptography.hazmat.primitives.asymmetric import rsa, padding +from cryptography.hazmat.primitives.asymmetric import rsa, padding as asym_padding class RSAUtils: @staticmethod - def generate_rsa_key_pair() -> (str, str): + def generate_rsa_key_pair(key_size: int = 2048) -> (str, str): """ - 生成RSA密钥对并返回Base64编码的公钥和私钥(DER格式) - :return: Tuple containing Base64 encoded public key and private key + 生成RSA密钥对 + :return: 私钥和公钥(Base64 编码,无标识符) """ # 生成RSA密钥对 private_key = rsa.generate_private_key( public_exponent=65537, - key_size=2048, + key_size=key_size, ) public_key = private_key.public_key() @@ -35,15 +40,16 @@ class RSAUtils: ) # 将DER格式的密钥编码为Base64 - private_key_b64 = base64.b64encode(private_key_der).decode('utf-8') - public_key_b64 = base64.b64encode(public_key_der).decode('utf-8') + private_key_b64 = base64.b64encode(private_key_der).decode("utf-8") + public_key_b64 = base64.b64encode(public_key_der).decode("utf-8") return private_key_b64, public_key_b64 @staticmethod def verify_rsa_keys(private_key: str, public_key: str) -> bool: """ - 使用 RSA 验证公钥和私钥是否匹配 + 使用 RSA 验证私钥和公钥是否匹配 + :param private_key: 私钥字符串 (Base64 编码,无标识符) :param public_key: 公钥字符串 (Base64 编码,无标识符) :return: 如果匹配则返回 True,否则返回 False @@ -67,8 +73,8 @@ class RSAUtils: message = b'test' encrypted_message = public_key.encrypt( message, - padding.OAEP( - mgf=padding.MGF1(algorithm=hashes.SHA256()), + asym_padding.OAEP( + mgf=asym_padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None ) @@ -76,8 +82,8 @@ class RSAUtils: decrypted_message = private_key.decrypt( encrypted_message, - padding.OAEP( - mgf=padding.MGF1(algorithm=hashes.SHA256()), + asym_padding.OAEP( + mgf=asym_padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None ) @@ -87,3 +93,101 @@ class RSAUtils: except Exception as e: print(f"RSA 密钥验证失败: {e}") return False + + +class HashUtils: + @staticmethod + def md5(data: str, encoding: str = "utf-8") -> str: + """ + 生成数据的MD5哈希值,并以字符串形式返回 + + :param data: 输入的数据,类型为字符串 + :param encoding: 字符串编码类型,默认使用UTF-8 + :return: 生成的MD5哈希字符串 + """ + encoded_data = data.encode(encoding) + return hashlib.md5(encoded_data).hexdigest() + + @staticmethod + def md5_bytes(data: str, encoding: str = "utf-8") -> bytes: + """ + 生成数据的MD5哈希值,并以字节形式返回 + + :param data: 输入的数据,类型为字符串 + :param encoding: 字符串编码类型,默认使用UTF-8 + :return: 生成的MD5哈希二进制数据 + """ + encoded_data = data.encode(encoding) + return hashlib.md5(encoded_data).digest() + + +class CryptoJsUtils: + + @staticmethod + def bytes_to_key(data: bytes, salt: bytes, output=48) -> bytes: + """ + 生成加密/解密所需的密钥和初始化向量 (IV) + """ + # extended from https://gist.github.com/gsakkis/4546068 + assert len(salt) == 8, len(salt) + data += salt + key = md5(data).digest() + final_key = key + while len(final_key) < output: + key = md5(key + data).digest() + final_key += key + return final_key[:output] + + @staticmethod + def encrypt(message: bytes, passphrase: bytes) -> bytes: + """ + 使用 CryptoJS 兼容的加密策略对消息进行加密 + """ + # This is a modified copy of https://stackoverflow.com/questions/36762098/how-to-decrypt-password-from-javascript-cryptojs-aes-encryptpassword-passphras + # 生成8字节的随机盐值 + salt = Random.new().read(8) + # 通过密码短语和盐值生成密钥和IV + key_iv = CryptoJsUtils.bytes_to_key(passphrase, salt, 32 + 16) + key = key_iv[:32] + iv = key_iv[32:] + # 创建AES加密器(CBC模式) + aes = AES.new(key, AES.MODE_CBC, iv) + # 应用PKCS#7填充 + padding_length = 16 - (len(message) % 16) + padding = bytes([padding_length] * padding_length) + padded_message = message + padding + # 加密消息 + encrypted = aes.encrypt(padded_message) + # 构建加密数据格式:b"Salted__" + salt + encrypted_message + salted_encrypted = b"Salted__" + salt + encrypted + # 返回Base64编码的加密数据 + return base64.b64encode(salted_encrypted) + + @staticmethod + def decrypt(encrypted: Union[str, bytes], passphrase: bytes) -> bytes: + """ + 使用 CryptoJS 兼容的解密策略对加密消息进行解密 + """ + # 确保输入是字节类型 + if isinstance(encrypted, str): + encrypted = encrypted.encode("utf-8") + # Base64 解码 + encrypted = base64.b64decode(encrypted) + # 检查前8字节是否为 "Salted__" + assert encrypted.startswith(b"Salted__"), "Invalid encrypted data format" + # 提取盐值 + salt = encrypted[8:16] + # 通过密码短语和盐值生成密钥和IV + key_iv = CryptoJsUtils.bytes_to_key(passphrase, salt, 32 + 16) + key = key_iv[:32] + iv = key_iv[32:] + # 创建AES解密器(CBC模式) + aes = AES.new(key, AES.MODE_CBC, iv) + # 解密加密部分 + decrypted_padded = aes.decrypt(encrypted[16:]) + # 移除PKCS#7填充 + padding_length = decrypted_padded[-1] + if isinstance(padding_length, str): + padding_length = ord(padding_length) + decrypted = decrypted_padded[:-padding_length] + return decrypted diff --git a/app/utils/http.py b/app/utils/http.py index bdef4272..b618c7ec 100644 --- a/app/utils/http.py +++ b/app/utils/http.py @@ -7,6 +7,7 @@ from requests import Session, Response from urllib3.exceptions import InsecureRequestWarning from app.log import logger +from app.utils.url import UrlUtils urllib3.disable_warnings(InsecureRequestWarning) @@ -253,7 +254,7 @@ class RequestUtils: return None if endpoint.startswith(("http://", "https://")): return endpoint - host = RequestUtils.standardize_base_url(host) + host = UrlUtils.standardize_base_url(host) return urljoin(host, endpoint) if host else endpoint @staticmethod @@ -269,7 +270,7 @@ class RequestUtils: # 如果路径为空,则默认为 '/' if path is None: path = '/' - host = RequestUtils.standardize_base_url(host) + host = UrlUtils.standardize_base_url(host) # 使用 urljoin 合并 host 和 path url = urljoin(host, path) # 解析当前 URL 的组成部分 diff --git a/app/utils/string.py b/app/utils/string.py index 9203713c..7f9c6f33 100644 --- a/app/utils/string.py +++ b/app/utils/string.py @@ -12,7 +12,6 @@ import dateutil.parser from app.schemas.types import MediaType - _special_domains = [ 'u2.dmhy.org', 'pt.ecust.pp.ua', @@ -788,3 +787,11 @@ class StringUtils: return f'{diff_minutes}分钟' else: return '' + + @staticmethod + def safe_strip(value) -> Optional[str]: + """ + 去除字符串两端的空白字符 + :return: 如果输入值不是 None,返回去除空白字符后的字符串,否则返回 None + """ + return value.strip() if value is not None else None diff --git a/app/utils/url.py b/app/utils/url.py new file mode 100644 index 00000000..39630ce7 --- /dev/null +++ b/app/utils/url.py @@ -0,0 +1,71 @@ +from typing import Optional +from urllib.parse import urljoin, urlparse, parse_qs, urlencode, urlunparse + +from app.log import logger + + +class UrlUtils: + + @staticmethod + def standardize_base_url(host: str) -> str: + """ + 标准化提供的主机地址,确保它以http://或https://开头,并且以斜杠(/)结尾 + :param host: 提供的主机地址字符串 + :return: 标准化后的主机地址字符串 + """ + if not host: + return host + if not host.endswith("/"): + host += "/" + if not host.startswith("http://") and not host.startswith("https://"): + host = "http://" + host + return host + + @staticmethod + def adapt_request_url(host: str, endpoint: str) -> Optional[str]: + """ + 基于传入的host,适配请求的URL,确保每个请求的URL是完整的,用于在发送请求前自动处理和修正请求的URL。 + :param host: 主机头 + :param endpoint: 端点 + :return: 完整的请求URL字符串 + """ + if not host and not endpoint: + return None + if endpoint.startswith(("http://", "https://")): + return endpoint + host = UrlUtils.standardize_base_url(host) + return urljoin(host, endpoint) if host else endpoint + + @staticmethod + def combine_url(host: str, path: Optional[str] = None, query: Optional[dict] = None) -> Optional[str]: + """ + 使用给定的主机头、路径和查询参数组合生成完整的URL。 + :param host: str, 主机头,例如 https://example.com + :param path: Optional[str], 包含路径和可能已经包含的查询参数的端点,例如 /path/to/resource?current=1 + :param query: Optional[dict], 可选,额外的查询参数,例如 {"key": "value"} + :return: str, 完整的请求URL字符串 + """ + try: + # 如果路径为空,则默认为 '/' + if path is None: + path = '/' + host = UrlUtils.standardize_base_url(host) + # 使用 urljoin 合并 host 和 path + url = urljoin(host, path) + # 解析当前 URL 的组成部分 + url_parts = urlparse(url) + # 解析已存在的查询参数,并与额外的查询参数合并 + query_params = parse_qs(url_parts.query) + if query: + for key, value in query.items(): + query_params[key] = value + + # 重新构建查询字符串 + query_string = urlencode(query_params, doseq=True) + # 构建完整的 URL + new_url_parts = url_parts._replace(query=query_string) + complete_url = urlunparse(new_url_parts) + return str(complete_url) + except Exception as e: + logger.debug(f"Error combining URL: {e}") + return None