Merge pull request #2656 from InfinityPacer/dev

This commit is contained in:
jxxghp
2024-08-26 10:08:31 +08:00
committed by GitHub
10 changed files with 236 additions and 100 deletions

View File

@@ -1,8 +1,6 @@
import gzip
import json
from hashlib import md5
from typing import Annotated, Callable
from typing import Any, Dict, Optional
from typing import Annotated, Callable, Any, Dict, Optional
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
from fastapi.responses import PlainTextResponse
@@ -11,7 +9,7 @@ from fastapi.routing import APIRoute
from app import schemas
from app.core.config import settings
from app.log import logger
from app.utils.common import decrypt
from app.utils.crypto import CryptoJsUtils, HashUtils
class GzipRequest(Request):
@@ -47,7 +45,7 @@ async def verify_server_enabled():
cookie_router = APIRouter(route_class=GzipRoute,
tags=['servcookie'],
tags=["servcookie"],
dependencies=[Depends(verify_server_enabled)])
@@ -100,15 +98,14 @@ def get_decrypted_cookie_data(uuid: str, password: str,
"""
加载本地加密数据并解密为Cookie
"""
key_md5 = md5()
key_md5.update((uuid + '-' + password).encode('utf-8'))
aes_key = (key_md5.hexdigest()[:16]).encode('utf-8')
combined_string = f"{uuid}-{password}"
aes_key = HashUtils.md5(combined_string)[:16].encode("utf-8")
if encrypted:
try:
decrypted_data = decrypt(encrypted, aes_key).decode('utf-8')
decrypted_data = CryptoJsUtils.decrypt(encrypted, aes_key).decode("utf-8")
decrypted_data = json.loads(decrypted_data)
if 'cookie_data' in decrypted_data:
if "cookie_data" in decrypted_data:
return decrypted_data
except Exception as e:
logger.error(f"解密Cookie数据失败{str(e)}")

View File

@@ -1,24 +1,27 @@
import json
from hashlib import md5
from typing import Any, Dict, Tuple, Optional
from app.core.config import settings
from app.utils.common import decrypt
from app.utils.crypto import CryptoJsUtils, HashUtils
from app.utils.http import RequestUtils
from app.utils.string import StringUtils
from app.utils.url import UrlUtils
class CookieCloudHelper:
_ignore_cookies: list = ["CookieAutoDeleteBrowsingDataCleanup", "CookieAutoDeleteCleaningDiscarded"]
def __init__(self):
self._sync_setting()
self.__sync_setting()
self._req = RequestUtils(content_type="application/json")
def _sync_setting(self):
self._server = settings.COOKIECLOUD_HOST
self._key = settings.COOKIECLOUD_KEY
self._password = settings.COOKIECLOUD_PASSWORD
def __sync_setting(self):
"""
同步CookieCloud配置项
"""
self._server = UrlUtils.standardize_base_url(settings.COOKIECLOUD_HOST)
self._key = StringUtils.safe_strip(settings.COOKIECLOUD_KEY)
self._password = StringUtils.safe_strip(settings.COOKIECLOUD_PASSWORD)
self._enable_local = settings.COOKIECLOUD_ENABLE_LOCAL
self._local_path = settings.COOKIE_PATH
@@ -28,7 +31,7 @@ class CookieCloudHelper:
:return: Cookie数据、错误信息
"""
# 更新为最新设置
self._sync_setting()
self.__sync_setting()
if ((not self._server and not self._enable_local)
or not self._key
@@ -37,11 +40,11 @@ class CookieCloudHelper:
if self._enable_local:
# 开启本地服务时,从本地直接读取数据
result = self._load_local_encrypt_data(self._key)
result = self.__load_local_encrypt_data(self._key)
if not result:
return {}, "未从本地CookieCloud服务加载到cookie数据请检查服务器设置、用户KEY及加密密码是否正确"
else:
req_url = "%s/get/%s" % (self._server, str(self._key).strip())
req_url = UrlUtils.combine_url(host=self._server, path=f"get/{self._key}")
ret = self._req.get_res(url=req_url)
if ret and ret.status_code == 200:
try:
@@ -59,9 +62,9 @@ class CookieCloudHelper:
if not encrypted:
return {}, "未获取到cookie密文"
else:
crypt_key = self._get_crypt_key()
crypt_key = self.__get_crypt_key()
try:
decrypted_data = decrypt(encrypted, crypt_key).decode('utf-8')
decrypted_data = CryptoJsUtils.decrypt(encrypted, crypt_key).decode("utf-8")
result = json.loads(decrypted_data)
except Exception as e:
return {}, "cookie解密失败" + str(e)
@@ -105,15 +108,17 @@ class CookieCloudHelper:
ret_cookies[domain] = cookie_str
return ret_cookies, ""
def _get_crypt_key(self) -> bytes:
def __get_crypt_key(self) -> bytes:
"""
使用UUID和密码生成CookieCloud的加解密密钥
"""
md5_generator = md5()
md5_generator.update((str(self._key).strip() + '-' + str(self._password).strip()).encode('utf-8'))
return (md5_generator.hexdigest()[:16]).encode('utf-8')
combined_string = f"{self._key}-{self._password}"
return HashUtils.md5(combined_string)[:16].encode("utf-8")
def _load_local_encrypt_data(self, uuid: str) -> Dict[str, Any]:
def __load_local_encrypt_data(self, uuid: str) -> Dict[str, Any]:
"""
获取本地CookieCloud数据
"""
file_path = self._local_path / f"{uuid}.json"
# 检查文件是否存在
if not file_path.exists():

View File

@@ -11,10 +11,10 @@ from app.core.config import settings
from app.log import logger
from app.schemas.types import MediaType
from app.utils.http import RequestUtils
from app.utils.url import UrlUtils
class Emby:
_host: str = None
_playhost: str = None
_apikey: str = None
@@ -26,10 +26,10 @@ class Emby:
return
self._host = host
if self._host:
self._host = RequestUtils.standardize_base_url(self._host)
self._host = UrlUtils.standardize_base_url(self._host)
self._playhost = play_host
if self._playhost:
self._playhost = RequestUtils.standardize_base_url(self._playhost)
self._playhost = UrlUtils.standardize_base_url(self._playhost)
self._apikey = apikey
self.user = self.get_user(settings.SUPERUSER)
self.folders = self.get_emby_folders()

View File

@@ -8,10 +8,10 @@ from app.core.config import settings
from app.log import logger
from app.schemas import MediaType
from app.utils.http import RequestUtils
from app.utils.url import UrlUtils
class Jellyfin:
_host: str = None
_apikey: str = None
_playhost: str = None
@@ -23,10 +23,10 @@ class Jellyfin:
return
self._host = host
if self._host:
self._host = RequestUtils.standardize_base_url(self._host)
self._host = UrlUtils.standardize_base_url(self._host)
self._playhost = play_host
if self._playhost:
self._playhost = RequestUtils.standardize_base_url(self._playhost)
self._playhost = UrlUtils.standardize_base_url(self._playhost)
self._apikey = apikey
self.user = self.get_user(settings.SUPERUSER)
self.serverid = self.get_server_id()

View File

@@ -13,6 +13,7 @@ from app.core.config import settings
from app.log import logger
from app.schemas import MediaType
from app.utils.http import RequestUtils
from app.utils.url import UrlUtils
class Plex:
@@ -25,10 +26,10 @@ class Plex:
return
self._host = host
if self._host:
self._host = RequestUtils.standardize_base_url(self._host)
self._host = UrlUtils.standardize_base_url(self._host)
self._playhost = play_host
if self._playhost:
self._playhost = RequestUtils.standardize_base_url(self._playhost)
self._playhost = UrlUtils.standardize_base_url(self._playhost)
self._token = token
if self._host and self._token:
try:

View File

@@ -1,11 +1,6 @@
import base64
import time
from hashlib import md5
from typing import Any
from Crypto import Random
from Crypto.Cipher import AES
from app.schemas import ImmediateException
@@ -41,48 +36,3 @@ def retry(ExceptionToCheck: Any,
return f_retry
return deco_retry
def bytes_to_key(data: bytes, salt: bytes, output=48) -> bytes:
# extended from https://gist.github.com/gsakkis/4546068
assert len(salt) == 8, len(salt)
data += salt
key = md5(data).digest()
final_key = key
while len(final_key) < output:
key = md5(key + data).digest()
final_key += key
return final_key[:output]
def encrypt(message: bytes, passphrase: bytes) -> bytes:
"""
CryptoJS 加密原文
This is a modified copy of https://stackoverflow.com/questions/36762098/how-to-decrypt-password-from-javascript-cryptojs-aes-encryptpassword-passphras
"""
salt = Random.new().read(8)
key_iv = bytes_to_key(passphrase, salt, 32 + 16)
key = key_iv[:32]
iv = key_iv[32:]
aes = AES.new(key, AES.MODE_CBC, iv)
length = 16 - (len(message) % 16)
data = message + (chr(length) * length).encode()
return base64.b64encode(b"Salted__" + salt + aes.encrypt(data))
def decrypt(encrypted: str | bytes, passphrase: bytes) -> bytes:
"""
CryptoJS 解密密文
来源同encrypt
"""
encrypted = base64.b64decode(encrypted)
assert encrypted[0:8] == b"Salted__"
salt = encrypted[8:16]
key_iv = bytes_to_key(passphrase, salt, 32 + 16)
key = key_iv[:32]
iv = key_iv[32:]
aes = AES.new(key, AES.MODE_CBC, iv)
data = aes.decrypt(encrypted[16:])
return data[:-(data[-1] if type(data[-1]) == int else ord(data[-1]))]

View File

@@ -1,22 +1,27 @@
import base64
import hashlib
from hashlib import md5
from typing import Union
from Crypto import Random
from Crypto.Cipher import AES
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives.asymmetric import rsa, padding as asym_padding
class RSAUtils:
@staticmethod
def generate_rsa_key_pair() -> (str, str):
def generate_rsa_key_pair(key_size: int = 2048) -> (str, str):
"""
生成RSA密钥对并返回Base64编码的公钥和私钥DER格式
:return: Tuple containing Base64 encoded public key and private key
生成RSA密钥对
:return: 私钥和公钥Base64 编码,无标识符)
"""
# 生成RSA密钥对
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
key_size=key_size,
)
public_key = private_key.public_key()
@@ -35,15 +40,16 @@ class RSAUtils:
)
# 将DER格式的密钥编码为Base64
private_key_b64 = base64.b64encode(private_key_der).decode('utf-8')
public_key_b64 = base64.b64encode(public_key_der).decode('utf-8')
private_key_b64 = base64.b64encode(private_key_der).decode("utf-8")
public_key_b64 = base64.b64encode(public_key_der).decode("utf-8")
return private_key_b64, public_key_b64
@staticmethod
def verify_rsa_keys(private_key: str, public_key: str) -> bool:
"""
使用 RSA 验证钥和钥是否匹配
使用 RSA 验证钥和钥是否匹配
:param private_key: 私钥字符串 (Base64 编码,无标识符)
:param public_key: 公钥字符串 (Base64 编码,无标识符)
:return: 如果匹配则返回 True否则返回 False
@@ -67,8 +73,8 @@ class RSAUtils:
message = b'test'
encrypted_message = public_key.encrypt(
message,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
asym_padding.OAEP(
mgf=asym_padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
@@ -76,8 +82,8 @@ class RSAUtils:
decrypted_message = private_key.decrypt(
encrypted_message,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
asym_padding.OAEP(
mgf=asym_padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
@@ -87,3 +93,101 @@ class RSAUtils:
except Exception as e:
print(f"RSA 密钥验证失败: {e}")
return False
class HashUtils:
@staticmethod
def md5(data: str, encoding: str = "utf-8") -> str:
"""
生成数据的MD5哈希值并以字符串形式返回
:param data: 输入的数据,类型为字符串
:param encoding: 字符串编码类型默认使用UTF-8
:return: 生成的MD5哈希字符串
"""
encoded_data = data.encode(encoding)
return hashlib.md5(encoded_data).hexdigest()
@staticmethod
def md5_bytes(data: str, encoding: str = "utf-8") -> bytes:
"""
生成数据的MD5哈希值并以字节形式返回
:param data: 输入的数据,类型为字符串
:param encoding: 字符串编码类型默认使用UTF-8
:return: 生成的MD5哈希二进制数据
"""
encoded_data = data.encode(encoding)
return hashlib.md5(encoded_data).digest()
class CryptoJsUtils:
@staticmethod
def bytes_to_key(data: bytes, salt: bytes, output=48) -> bytes:
"""
生成加密/解密所需的密钥和初始化向量 (IV)
"""
# extended from https://gist.github.com/gsakkis/4546068
assert len(salt) == 8, len(salt)
data += salt
key = md5(data).digest()
final_key = key
while len(final_key) < output:
key = md5(key + data).digest()
final_key += key
return final_key[:output]
@staticmethod
def encrypt(message: bytes, passphrase: bytes) -> bytes:
"""
使用 CryptoJS 兼容的加密策略对消息进行加密
"""
# This is a modified copy of https://stackoverflow.com/questions/36762098/how-to-decrypt-password-from-javascript-cryptojs-aes-encryptpassword-passphras
# 生成8字节的随机盐值
salt = Random.new().read(8)
# 通过密码短语和盐值生成密钥和IV
key_iv = CryptoJsUtils.bytes_to_key(passphrase, salt, 32 + 16)
key = key_iv[:32]
iv = key_iv[32:]
# 创建AES加密器CBC模式
aes = AES.new(key, AES.MODE_CBC, iv)
# 应用PKCS#7填充
padding_length = 16 - (len(message) % 16)
padding = bytes([padding_length] * padding_length)
padded_message = message + padding
# 加密消息
encrypted = aes.encrypt(padded_message)
# 构建加密数据格式b"Salted__" + salt + encrypted_message
salted_encrypted = b"Salted__" + salt + encrypted
# 返回Base64编码的加密数据
return base64.b64encode(salted_encrypted)
@staticmethod
def decrypt(encrypted: Union[str, bytes], passphrase: bytes) -> bytes:
"""
使用 CryptoJS 兼容的解密策略对加密消息进行解密
"""
# 确保输入是字节类型
if isinstance(encrypted, str):
encrypted = encrypted.encode("utf-8")
# Base64 解码
encrypted = base64.b64decode(encrypted)
# 检查前8字节是否为 "Salted__"
assert encrypted.startswith(b"Salted__"), "Invalid encrypted data format"
# 提取盐值
salt = encrypted[8:16]
# 通过密码短语和盐值生成密钥和IV
key_iv = CryptoJsUtils.bytes_to_key(passphrase, salt, 32 + 16)
key = key_iv[:32]
iv = key_iv[32:]
# 创建AES解密器CBC模式
aes = AES.new(key, AES.MODE_CBC, iv)
# 解密加密部分
decrypted_padded = aes.decrypt(encrypted[16:])
# 移除PKCS#7填充
padding_length = decrypted_padded[-1]
if isinstance(padding_length, str):
padding_length = ord(padding_length)
decrypted = decrypted_padded[:-padding_length]
return decrypted

View File

@@ -7,6 +7,7 @@ from requests import Session, Response
from urllib3.exceptions import InsecureRequestWarning
from app.log import logger
from app.utils.url import UrlUtils
urllib3.disable_warnings(InsecureRequestWarning)
@@ -253,7 +254,7 @@ class RequestUtils:
return None
if endpoint.startswith(("http://", "https://")):
return endpoint
host = RequestUtils.standardize_base_url(host)
host = UrlUtils.standardize_base_url(host)
return urljoin(host, endpoint) if host else endpoint
@staticmethod
@@ -269,7 +270,7 @@ class RequestUtils:
# 如果路径为空,则默认为 '/'
if path is None:
path = '/'
host = RequestUtils.standardize_base_url(host)
host = UrlUtils.standardize_base_url(host)
# 使用 urljoin 合并 host 和 path
url = urljoin(host, path)
# 解析当前 URL 的组成部分

View File

@@ -12,7 +12,6 @@ import dateutil.parser
from app.schemas.types import MediaType
_special_domains = [
'u2.dmhy.org',
'pt.ecust.pp.ua',
@@ -788,3 +787,11 @@ class StringUtils:
return f'{diff_minutes}分钟'
else:
return ''
@staticmethod
def safe_strip(value) -> Optional[str]:
"""
去除字符串两端的空白字符
:return: 如果输入值不是 None返回去除空白字符后的字符串否则返回 None
"""
return value.strip() if value is not None else None

71
app/utils/url.py Normal file
View File

@@ -0,0 +1,71 @@
from typing import Optional
from urllib.parse import urljoin, urlparse, parse_qs, urlencode, urlunparse
from app.log import logger
class UrlUtils:
@staticmethod
def standardize_base_url(host: str) -> str:
"""
标准化提供的主机地址确保它以http://或https://开头,并且以斜杠(/)结尾
:param host: 提供的主机地址字符串
:return: 标准化后的主机地址字符串
"""
if not host:
return host
if not host.endswith("/"):
host += "/"
if not host.startswith("http://") and not host.startswith("https://"):
host = "http://" + host
return host
@staticmethod
def adapt_request_url(host: str, endpoint: str) -> Optional[str]:
"""
基于传入的host适配请求的URL确保每个请求的URL是完整的用于在发送请求前自动处理和修正请求的URL。
:param host: 主机头
:param endpoint: 端点
:return: 完整的请求URL字符串
"""
if not host and not endpoint:
return None
if endpoint.startswith(("http://", "https://")):
return endpoint
host = UrlUtils.standardize_base_url(host)
return urljoin(host, endpoint) if host else endpoint
@staticmethod
def combine_url(host: str, path: Optional[str] = None, query: Optional[dict] = None) -> Optional[str]:
"""
使用给定的主机头、路径和查询参数组合生成完整的URL。
:param host: str, 主机头,例如 https://example.com
:param path: Optional[str], 包含路径和可能已经包含的查询参数的端点,例如 /path/to/resource?current=1
:param query: Optional[dict], 可选,额外的查询参数,例如 {"key": "value"}
:return: str, 完整的请求URL字符串
"""
try:
# 如果路径为空,则默认为 '/'
if path is None:
path = '/'
host = UrlUtils.standardize_base_url(host)
# 使用 urljoin 合并 host 和 path
url = urljoin(host, path)
# 解析当前 URL 的组成部分
url_parts = urlparse(url)
# 解析已存在的查询参数,并与额外的查询参数合并
query_params = parse_qs(url_parts.query)
if query:
for key, value in query.items():
query_params[key] = value
# 重新构建查询字符串
query_string = urlencode(query_params, doseq=True)
# 构建完整的 URL
new_url_parts = url_parts._replace(query=query_string)
complete_url = urlunparse(new_url_parts)
return str(complete_url)
except Exception as e:
logger.debug(f"Error combining URL: {e}")
return None