mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-02-12 06:55:46 +08:00
Merge pull request #2656 from InfinityPacer/dev
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
import gzip
|
||||
import json
|
||||
from hashlib import md5
|
||||
from typing import Annotated, Callable
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Annotated, Callable, Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
|
||||
from fastapi.responses import PlainTextResponse
|
||||
@@ -11,7 +9,7 @@ from fastapi.routing import APIRoute
|
||||
from app import schemas
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.common import decrypt
|
||||
from app.utils.crypto import CryptoJsUtils, HashUtils
|
||||
|
||||
|
||||
class GzipRequest(Request):
|
||||
@@ -47,7 +45,7 @@ async def verify_server_enabled():
|
||||
|
||||
|
||||
cookie_router = APIRouter(route_class=GzipRoute,
|
||||
tags=['servcookie'],
|
||||
tags=["servcookie"],
|
||||
dependencies=[Depends(verify_server_enabled)])
|
||||
|
||||
|
||||
@@ -100,15 +98,14 @@ def get_decrypted_cookie_data(uuid: str, password: str,
|
||||
"""
|
||||
加载本地加密数据并解密为Cookie
|
||||
"""
|
||||
key_md5 = md5()
|
||||
key_md5.update((uuid + '-' + password).encode('utf-8'))
|
||||
aes_key = (key_md5.hexdigest()[:16]).encode('utf-8')
|
||||
combined_string = f"{uuid}-{password}"
|
||||
aes_key = HashUtils.md5(combined_string)[:16].encode("utf-8")
|
||||
|
||||
if encrypted:
|
||||
try:
|
||||
decrypted_data = decrypt(encrypted, aes_key).decode('utf-8')
|
||||
decrypted_data = CryptoJsUtils.decrypt(encrypted, aes_key).decode("utf-8")
|
||||
decrypted_data = json.loads(decrypted_data)
|
||||
if 'cookie_data' in decrypted_data:
|
||||
if "cookie_data" in decrypted_data:
|
||||
return decrypted_data
|
||||
except Exception as e:
|
||||
logger.error(f"解密Cookie数据失败:{str(e)}")
|
||||
|
||||
@@ -1,24 +1,27 @@
|
||||
import json
|
||||
from hashlib import md5
|
||||
from typing import Any, Dict, Tuple, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.utils.common import decrypt
|
||||
from app.utils.crypto import CryptoJsUtils, HashUtils
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.url import UrlUtils
|
||||
|
||||
|
||||
class CookieCloudHelper:
|
||||
_ignore_cookies: list = ["CookieAutoDeleteBrowsingDataCleanup", "CookieAutoDeleteCleaningDiscarded"]
|
||||
|
||||
def __init__(self):
|
||||
self._sync_setting()
|
||||
self.__sync_setting()
|
||||
self._req = RequestUtils(content_type="application/json")
|
||||
|
||||
def _sync_setting(self):
|
||||
self._server = settings.COOKIECLOUD_HOST
|
||||
self._key = settings.COOKIECLOUD_KEY
|
||||
self._password = settings.COOKIECLOUD_PASSWORD
|
||||
def __sync_setting(self):
|
||||
"""
|
||||
同步CookieCloud配置项
|
||||
"""
|
||||
self._server = UrlUtils.standardize_base_url(settings.COOKIECLOUD_HOST)
|
||||
self._key = StringUtils.safe_strip(settings.COOKIECLOUD_KEY)
|
||||
self._password = StringUtils.safe_strip(settings.COOKIECLOUD_PASSWORD)
|
||||
self._enable_local = settings.COOKIECLOUD_ENABLE_LOCAL
|
||||
self._local_path = settings.COOKIE_PATH
|
||||
|
||||
@@ -28,7 +31,7 @@ class CookieCloudHelper:
|
||||
:return: Cookie数据、错误信息
|
||||
"""
|
||||
# 更新为最新设置
|
||||
self._sync_setting()
|
||||
self.__sync_setting()
|
||||
|
||||
if ((not self._server and not self._enable_local)
|
||||
or not self._key
|
||||
@@ -37,11 +40,11 @@ class CookieCloudHelper:
|
||||
|
||||
if self._enable_local:
|
||||
# 开启本地服务时,从本地直接读取数据
|
||||
result = self._load_local_encrypt_data(self._key)
|
||||
result = self.__load_local_encrypt_data(self._key)
|
||||
if not result:
|
||||
return {}, "未从本地CookieCloud服务加载到cookie数据,请检查服务器设置、用户KEY及加密密码是否正确"
|
||||
else:
|
||||
req_url = "%s/get/%s" % (self._server, str(self._key).strip())
|
||||
req_url = UrlUtils.combine_url(host=self._server, path=f"get/{self._key}")
|
||||
ret = self._req.get_res(url=req_url)
|
||||
if ret and ret.status_code == 200:
|
||||
try:
|
||||
@@ -59,9 +62,9 @@ class CookieCloudHelper:
|
||||
if not encrypted:
|
||||
return {}, "未获取到cookie密文"
|
||||
else:
|
||||
crypt_key = self._get_crypt_key()
|
||||
crypt_key = self.__get_crypt_key()
|
||||
try:
|
||||
decrypted_data = decrypt(encrypted, crypt_key).decode('utf-8')
|
||||
decrypted_data = CryptoJsUtils.decrypt(encrypted, crypt_key).decode("utf-8")
|
||||
result = json.loads(decrypted_data)
|
||||
except Exception as e:
|
||||
return {}, "cookie解密失败:" + str(e)
|
||||
@@ -105,15 +108,17 @@ class CookieCloudHelper:
|
||||
ret_cookies[domain] = cookie_str
|
||||
return ret_cookies, ""
|
||||
|
||||
def _get_crypt_key(self) -> bytes:
|
||||
def __get_crypt_key(self) -> bytes:
|
||||
"""
|
||||
使用UUID和密码生成CookieCloud的加解密密钥
|
||||
"""
|
||||
md5_generator = md5()
|
||||
md5_generator.update((str(self._key).strip() + '-' + str(self._password).strip()).encode('utf-8'))
|
||||
return (md5_generator.hexdigest()[:16]).encode('utf-8')
|
||||
combined_string = f"{self._key}-{self._password}"
|
||||
return HashUtils.md5(combined_string)[:16].encode("utf-8")
|
||||
|
||||
def _load_local_encrypt_data(self, uuid: str) -> Dict[str, Any]:
|
||||
def __load_local_encrypt_data(self, uuid: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取本地CookieCloud数据
|
||||
"""
|
||||
file_path = self._local_path / f"{uuid}.json"
|
||||
# 检查文件是否存在
|
||||
if not file_path.exists():
|
||||
|
||||
@@ -11,10 +11,10 @@ from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.url import UrlUtils
|
||||
|
||||
|
||||
class Emby:
|
||||
|
||||
_host: str = None
|
||||
_playhost: str = None
|
||||
_apikey: str = None
|
||||
@@ -26,10 +26,10 @@ class Emby:
|
||||
return
|
||||
self._host = host
|
||||
if self._host:
|
||||
self._host = RequestUtils.standardize_base_url(self._host)
|
||||
self._host = UrlUtils.standardize_base_url(self._host)
|
||||
self._playhost = play_host
|
||||
if self._playhost:
|
||||
self._playhost = RequestUtils.standardize_base_url(self._playhost)
|
||||
self._playhost = UrlUtils.standardize_base_url(self._playhost)
|
||||
self._apikey = apikey
|
||||
self.user = self.get_user(settings.SUPERUSER)
|
||||
self.folders = self.get_emby_folders()
|
||||
|
||||
@@ -8,10 +8,10 @@ from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.url import UrlUtils
|
||||
|
||||
|
||||
class Jellyfin:
|
||||
|
||||
_host: str = None
|
||||
_apikey: str = None
|
||||
_playhost: str = None
|
||||
@@ -23,10 +23,10 @@ class Jellyfin:
|
||||
return
|
||||
self._host = host
|
||||
if self._host:
|
||||
self._host = RequestUtils.standardize_base_url(self._host)
|
||||
self._host = UrlUtils.standardize_base_url(self._host)
|
||||
self._playhost = play_host
|
||||
if self._playhost:
|
||||
self._playhost = RequestUtils.standardize_base_url(self._playhost)
|
||||
self._playhost = UrlUtils.standardize_base_url(self._playhost)
|
||||
self._apikey = apikey
|
||||
self.user = self.get_user(settings.SUPERUSER)
|
||||
self.serverid = self.get_server_id()
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.url import UrlUtils
|
||||
|
||||
|
||||
class Plex:
|
||||
@@ -25,10 +26,10 @@ class Plex:
|
||||
return
|
||||
self._host = host
|
||||
if self._host:
|
||||
self._host = RequestUtils.standardize_base_url(self._host)
|
||||
self._host = UrlUtils.standardize_base_url(self._host)
|
||||
self._playhost = play_host
|
||||
if self._playhost:
|
||||
self._playhost = RequestUtils.standardize_base_url(self._playhost)
|
||||
self._playhost = UrlUtils.standardize_base_url(self._playhost)
|
||||
self._token = token
|
||||
if self._host and self._token:
|
||||
try:
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
import base64
|
||||
import time
|
||||
from hashlib import md5
|
||||
from typing import Any
|
||||
|
||||
from Crypto import Random
|
||||
from Crypto.Cipher import AES
|
||||
|
||||
from app.schemas import ImmediateException
|
||||
|
||||
|
||||
@@ -41,48 +36,3 @@ def retry(ExceptionToCheck: Any,
|
||||
return f_retry
|
||||
|
||||
return deco_retry
|
||||
|
||||
|
||||
def bytes_to_key(data: bytes, salt: bytes, output=48) -> bytes:
|
||||
# extended from https://gist.github.com/gsakkis/4546068
|
||||
assert len(salt) == 8, len(salt)
|
||||
data += salt
|
||||
key = md5(data).digest()
|
||||
final_key = key
|
||||
while len(final_key) < output:
|
||||
key = md5(key + data).digest()
|
||||
final_key += key
|
||||
return final_key[:output]
|
||||
|
||||
|
||||
def encrypt(message: bytes, passphrase: bytes) -> bytes:
|
||||
"""
|
||||
CryptoJS 加密原文
|
||||
|
||||
This is a modified copy of https://stackoverflow.com/questions/36762098/how-to-decrypt-password-from-javascript-cryptojs-aes-encryptpassword-passphras
|
||||
"""
|
||||
salt = Random.new().read(8)
|
||||
key_iv = bytes_to_key(passphrase, salt, 32 + 16)
|
||||
key = key_iv[:32]
|
||||
iv = key_iv[32:]
|
||||
aes = AES.new(key, AES.MODE_CBC, iv)
|
||||
length = 16 - (len(message) % 16)
|
||||
data = message + (chr(length) * length).encode()
|
||||
return base64.b64encode(b"Salted__" + salt + aes.encrypt(data))
|
||||
|
||||
|
||||
def decrypt(encrypted: str | bytes, passphrase: bytes) -> bytes:
|
||||
"""
|
||||
CryptoJS 解密密文
|
||||
|
||||
来源同encrypt
|
||||
"""
|
||||
encrypted = base64.b64decode(encrypted)
|
||||
assert encrypted[0:8] == b"Salted__"
|
||||
salt = encrypted[8:16]
|
||||
key_iv = bytes_to_key(passphrase, salt, 32 + 16)
|
||||
key = key_iv[:32]
|
||||
iv = key_iv[32:]
|
||||
aes = AES.new(key, AES.MODE_CBC, iv)
|
||||
data = aes.decrypt(encrypted[16:])
|
||||
return data[:-(data[-1] if type(data[-1]) == int else ord(data[-1]))]
|
||||
|
||||
@@ -1,22 +1,27 @@
|
||||
import base64
|
||||
import hashlib
|
||||
from hashlib import md5
|
||||
from typing import Union
|
||||
|
||||
from Crypto import Random
|
||||
from Crypto.Cipher import AES
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization, hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, padding as asym_padding
|
||||
|
||||
|
||||
class RSAUtils:
|
||||
|
||||
@staticmethod
|
||||
def generate_rsa_key_pair() -> (str, str):
|
||||
def generate_rsa_key_pair(key_size: int = 2048) -> (str, str):
|
||||
"""
|
||||
生成RSA密钥对并返回Base64编码的公钥和私钥(DER格式)
|
||||
:return: Tuple containing Base64 encoded public key and private key
|
||||
生成RSA密钥对
|
||||
:return: 私钥和公钥(Base64 编码,无标识符)
|
||||
"""
|
||||
# 生成RSA密钥对
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
key_size=key_size,
|
||||
)
|
||||
|
||||
public_key = private_key.public_key()
|
||||
@@ -35,15 +40,16 @@ class RSAUtils:
|
||||
)
|
||||
|
||||
# 将DER格式的密钥编码为Base64
|
||||
private_key_b64 = base64.b64encode(private_key_der).decode('utf-8')
|
||||
public_key_b64 = base64.b64encode(public_key_der).decode('utf-8')
|
||||
private_key_b64 = base64.b64encode(private_key_der).decode("utf-8")
|
||||
public_key_b64 = base64.b64encode(public_key_der).decode("utf-8")
|
||||
|
||||
return private_key_b64, public_key_b64
|
||||
|
||||
@staticmethod
|
||||
def verify_rsa_keys(private_key: str, public_key: str) -> bool:
|
||||
"""
|
||||
使用 RSA 验证公钥和私钥是否匹配
|
||||
使用 RSA 验证私钥和公钥是否匹配
|
||||
|
||||
:param private_key: 私钥字符串 (Base64 编码,无标识符)
|
||||
:param public_key: 公钥字符串 (Base64 编码,无标识符)
|
||||
:return: 如果匹配则返回 True,否则返回 False
|
||||
@@ -67,8 +73,8 @@ class RSAUtils:
|
||||
message = b'test'
|
||||
encrypted_message = public_key.encrypt(
|
||||
message,
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
asym_padding.OAEP(
|
||||
mgf=asym_padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
@@ -76,8 +82,8 @@ class RSAUtils:
|
||||
|
||||
decrypted_message = private_key.decrypt(
|
||||
encrypted_message,
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
asym_padding.OAEP(
|
||||
mgf=asym_padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
@@ -87,3 +93,101 @@ class RSAUtils:
|
||||
except Exception as e:
|
||||
print(f"RSA 密钥验证失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class HashUtils:
|
||||
@staticmethod
|
||||
def md5(data: str, encoding: str = "utf-8") -> str:
|
||||
"""
|
||||
生成数据的MD5哈希值,并以字符串形式返回
|
||||
|
||||
:param data: 输入的数据,类型为字符串
|
||||
:param encoding: 字符串编码类型,默认使用UTF-8
|
||||
:return: 生成的MD5哈希字符串
|
||||
"""
|
||||
encoded_data = data.encode(encoding)
|
||||
return hashlib.md5(encoded_data).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def md5_bytes(data: str, encoding: str = "utf-8") -> bytes:
|
||||
"""
|
||||
生成数据的MD5哈希值,并以字节形式返回
|
||||
|
||||
:param data: 输入的数据,类型为字符串
|
||||
:param encoding: 字符串编码类型,默认使用UTF-8
|
||||
:return: 生成的MD5哈希二进制数据
|
||||
"""
|
||||
encoded_data = data.encode(encoding)
|
||||
return hashlib.md5(encoded_data).digest()
|
||||
|
||||
|
||||
class CryptoJsUtils:
|
||||
|
||||
@staticmethod
|
||||
def bytes_to_key(data: bytes, salt: bytes, output=48) -> bytes:
|
||||
"""
|
||||
生成加密/解密所需的密钥和初始化向量 (IV)
|
||||
"""
|
||||
# extended from https://gist.github.com/gsakkis/4546068
|
||||
assert len(salt) == 8, len(salt)
|
||||
data += salt
|
||||
key = md5(data).digest()
|
||||
final_key = key
|
||||
while len(final_key) < output:
|
||||
key = md5(key + data).digest()
|
||||
final_key += key
|
||||
return final_key[:output]
|
||||
|
||||
@staticmethod
|
||||
def encrypt(message: bytes, passphrase: bytes) -> bytes:
|
||||
"""
|
||||
使用 CryptoJS 兼容的加密策略对消息进行加密
|
||||
"""
|
||||
# This is a modified copy of https://stackoverflow.com/questions/36762098/how-to-decrypt-password-from-javascript-cryptojs-aes-encryptpassword-passphras
|
||||
# 生成8字节的随机盐值
|
||||
salt = Random.new().read(8)
|
||||
# 通过密码短语和盐值生成密钥和IV
|
||||
key_iv = CryptoJsUtils.bytes_to_key(passphrase, salt, 32 + 16)
|
||||
key = key_iv[:32]
|
||||
iv = key_iv[32:]
|
||||
# 创建AES加密器(CBC模式)
|
||||
aes = AES.new(key, AES.MODE_CBC, iv)
|
||||
# 应用PKCS#7填充
|
||||
padding_length = 16 - (len(message) % 16)
|
||||
padding = bytes([padding_length] * padding_length)
|
||||
padded_message = message + padding
|
||||
# 加密消息
|
||||
encrypted = aes.encrypt(padded_message)
|
||||
# 构建加密数据格式:b"Salted__" + salt + encrypted_message
|
||||
salted_encrypted = b"Salted__" + salt + encrypted
|
||||
# 返回Base64编码的加密数据
|
||||
return base64.b64encode(salted_encrypted)
|
||||
|
||||
@staticmethod
|
||||
def decrypt(encrypted: Union[str, bytes], passphrase: bytes) -> bytes:
|
||||
"""
|
||||
使用 CryptoJS 兼容的解密策略对加密消息进行解密
|
||||
"""
|
||||
# 确保输入是字节类型
|
||||
if isinstance(encrypted, str):
|
||||
encrypted = encrypted.encode("utf-8")
|
||||
# Base64 解码
|
||||
encrypted = base64.b64decode(encrypted)
|
||||
# 检查前8字节是否为 "Salted__"
|
||||
assert encrypted.startswith(b"Salted__"), "Invalid encrypted data format"
|
||||
# 提取盐值
|
||||
salt = encrypted[8:16]
|
||||
# 通过密码短语和盐值生成密钥和IV
|
||||
key_iv = CryptoJsUtils.bytes_to_key(passphrase, salt, 32 + 16)
|
||||
key = key_iv[:32]
|
||||
iv = key_iv[32:]
|
||||
# 创建AES解密器(CBC模式)
|
||||
aes = AES.new(key, AES.MODE_CBC, iv)
|
||||
# 解密加密部分
|
||||
decrypted_padded = aes.decrypt(encrypted[16:])
|
||||
# 移除PKCS#7填充
|
||||
padding_length = decrypted_padded[-1]
|
||||
if isinstance(padding_length, str):
|
||||
padding_length = ord(padding_length)
|
||||
decrypted = decrypted_padded[:-padding_length]
|
||||
return decrypted
|
||||
|
||||
@@ -7,6 +7,7 @@ from requests import Session, Response
|
||||
from urllib3.exceptions import InsecureRequestWarning
|
||||
|
||||
from app.log import logger
|
||||
from app.utils.url import UrlUtils
|
||||
|
||||
urllib3.disable_warnings(InsecureRequestWarning)
|
||||
|
||||
@@ -253,7 +254,7 @@ class RequestUtils:
|
||||
return None
|
||||
if endpoint.startswith(("http://", "https://")):
|
||||
return endpoint
|
||||
host = RequestUtils.standardize_base_url(host)
|
||||
host = UrlUtils.standardize_base_url(host)
|
||||
return urljoin(host, endpoint) if host else endpoint
|
||||
|
||||
@staticmethod
|
||||
@@ -269,7 +270,7 @@ class RequestUtils:
|
||||
# 如果路径为空,则默认为 '/'
|
||||
if path is None:
|
||||
path = '/'
|
||||
host = RequestUtils.standardize_base_url(host)
|
||||
host = UrlUtils.standardize_base_url(host)
|
||||
# 使用 urljoin 合并 host 和 path
|
||||
url = urljoin(host, path)
|
||||
# 解析当前 URL 的组成部分
|
||||
|
||||
@@ -12,7 +12,6 @@ import dateutil.parser
|
||||
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
_special_domains = [
|
||||
'u2.dmhy.org',
|
||||
'pt.ecust.pp.ua',
|
||||
@@ -788,3 +787,11 @@ class StringUtils:
|
||||
return f'{diff_minutes}分钟'
|
||||
else:
|
||||
return ''
|
||||
|
||||
@staticmethod
|
||||
def safe_strip(value) -> Optional[str]:
|
||||
"""
|
||||
去除字符串两端的空白字符
|
||||
:return: 如果输入值不是 None,返回去除空白字符后的字符串,否则返回 None
|
||||
"""
|
||||
return value.strip() if value is not None else None
|
||||
|
||||
71
app/utils/url.py
Normal file
71
app/utils/url.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin, urlparse, parse_qs, urlencode, urlunparse
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class UrlUtils:
|
||||
|
||||
@staticmethod
|
||||
def standardize_base_url(host: str) -> str:
|
||||
"""
|
||||
标准化提供的主机地址,确保它以http://或https://开头,并且以斜杠(/)结尾
|
||||
:param host: 提供的主机地址字符串
|
||||
:return: 标准化后的主机地址字符串
|
||||
"""
|
||||
if not host:
|
||||
return host
|
||||
if not host.endswith("/"):
|
||||
host += "/"
|
||||
if not host.startswith("http://") and not host.startswith("https://"):
|
||||
host = "http://" + host
|
||||
return host
|
||||
|
||||
@staticmethod
|
||||
def adapt_request_url(host: str, endpoint: str) -> Optional[str]:
|
||||
"""
|
||||
基于传入的host,适配请求的URL,确保每个请求的URL是完整的,用于在发送请求前自动处理和修正请求的URL。
|
||||
:param host: 主机头
|
||||
:param endpoint: 端点
|
||||
:return: 完整的请求URL字符串
|
||||
"""
|
||||
if not host and not endpoint:
|
||||
return None
|
||||
if endpoint.startswith(("http://", "https://")):
|
||||
return endpoint
|
||||
host = UrlUtils.standardize_base_url(host)
|
||||
return urljoin(host, endpoint) if host else endpoint
|
||||
|
||||
@staticmethod
|
||||
def combine_url(host: str, path: Optional[str] = None, query: Optional[dict] = None) -> Optional[str]:
|
||||
"""
|
||||
使用给定的主机头、路径和查询参数组合生成完整的URL。
|
||||
:param host: str, 主机头,例如 https://example.com
|
||||
:param path: Optional[str], 包含路径和可能已经包含的查询参数的端点,例如 /path/to/resource?current=1
|
||||
:param query: Optional[dict], 可选,额外的查询参数,例如 {"key": "value"}
|
||||
:return: str, 完整的请求URL字符串
|
||||
"""
|
||||
try:
|
||||
# 如果路径为空,则默认为 '/'
|
||||
if path is None:
|
||||
path = '/'
|
||||
host = UrlUtils.standardize_base_url(host)
|
||||
# 使用 urljoin 合并 host 和 path
|
||||
url = urljoin(host, path)
|
||||
# 解析当前 URL 的组成部分
|
||||
url_parts = urlparse(url)
|
||||
# 解析已存在的查询参数,并与额外的查询参数合并
|
||||
query_params = parse_qs(url_parts.query)
|
||||
if query:
|
||||
for key, value in query.items():
|
||||
query_params[key] = value
|
||||
|
||||
# 重新构建查询字符串
|
||||
query_string = urlencode(query_params, doseq=True)
|
||||
# 构建完整的 URL
|
||||
new_url_parts = url_parts._replace(query=query_string)
|
||||
complete_url = urlunparse(new_url_parts)
|
||||
return str(complete_url)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error combining URL: {e}")
|
||||
return None
|
||||
Reference in New Issue
Block a user