From 4f5bd3a50ccc2949beb2fc78cdd39417672c35d6 Mon Sep 17 00:00:00 2001 From: Kevin Lin Date: Sun, 25 Jan 2026 17:50:16 +0800 Subject: [PATCH] feat(http): reuse connection --- dashscope/__init__.py | 280 ++++++++ dashscope/api_entities/api_request_factory.py | 8 + dashscope/api_entities/http_request.py | 215 ++++-- dashscope/common/aio_session_manager.py | 410 +++++++++++ dashscope/common/session_manager.py | 315 ++++++++ samples/test_generation.py | 5 + tests/unit/test_aio_connection_pool.py | 510 +++++++++++++ tests/unit/test_connection_pool.py | 675 ++++++++++++++++++ 8 files changed, 2371 insertions(+), 47 deletions(-) create mode 100644 dashscope/common/aio_session_manager.py create mode 100644 dashscope/common/session_manager.py create mode 100644 tests/unit/test_aio_connection_pool.py create mode 100644 tests/unit/test_connection_pool.py diff --git a/dashscope/__init__.py b/dashscope/__init__.py index 744b269..e439c13 100644 --- a/dashscope/__init__.py +++ b/dashscope/__init__.py @@ -25,6 +25,8 @@ base_http_api_url, base_websocket_api_url, ) +from dashscope.common.aio_session_manager import AioSessionManager +from dashscope.common.session_manager import SessionManager from dashscope.customize.deployments import Deployments from dashscope.customize.finetunes import FineTunes from dashscope.embeddings.batch_text_embedding import BatchTextEmbedding @@ -64,6 +66,276 @@ list_tokenizers, ) + +def enable_http_connection_pool( + pool_connections: int = None, + pool_maxsize: int = None, + max_retries: int = None, + pool_block: bool = None, +): + """ + 启用 HTTP 连接池复用 + + 启用后,所有同步 HTTP 请求将复用连接,显著减少延迟。 + + Args: + pool_connections: 连接池大小,默认 10 + - 低并发(< 10 req/s): 10 + - 中并发(10-50 req/s): 20-30 + - 高并发(> 50 req/s): 50-100 + + pool_maxsize: 最大连接数,默认 20 + - 应该 >= pool_connections + - 低并发: 20 + - 中并发: 50 + - 高并发: 100-200 + + max_retries: 重试次数,默认 3 + - 网络稳定: 3 + - 网络不稳定: 5-10 + + pool_block: 连接池满时是否阻塞,默认 False + - False: 连接池满时创建新连接(推荐) + - True: 连接池满时等待可用连接 + + Examples: + >>> import dashscope + >>> + >>> # 使用默认配置 + >>> dashscope.enable_http_connection_pool() + >>> + >>> # 自定义配置 + >>> dashscope.enable_http_connection_pool( + ... pool_connections=20, + ... pool_maxsize=50 + ... ) + >>> + >>> # 之后的所有请求都会复用连接 + >>> Generation.call(model='qwen-turbo', prompt='Hello') + """ + SessionManager.get_instance().enable( + pool_connections=pool_connections, + pool_maxsize=pool_maxsize, + max_retries=max_retries, + pool_block=pool_block, + ) + + +def disable_http_connection_pool(): + """ + 禁用 HTTP 连接池复用 + + 恢复到原有的每次请求创建新连接的行为。 + + Example: + >>> import dashscope + >>> dashscope.disable_http_connection_pool() + """ + SessionManager.get_instance().disable() + + +def reset_http_connection_pool(): + """ + 重置 HTTP 连接池 + + 用于处理连接问题或网络切换场景。 + + Example: + >>> import dashscope + >>> dashscope.reset_http_connection_pool() + """ + SessionManager.get_instance().reset() + + +def configure_http_connection_pool( + pool_connections: int = None, + pool_maxsize: int = None, + max_retries: int = None, + pool_block: bool = None, +): + """ + 配置 HTTP 连接池参数 + + 运行时动态调整连接池配置。 + + Args: + pool_connections: 连接池大小 + pool_maxsize: 最大连接数 + max_retries: 重试次数 + pool_block: 连接池满时是否阻塞 + + Examples: + >>> import dashscope + >>> + >>> # 调整单个参数 + >>> dashscope.configure_http_connection_pool(pool_maxsize=100) + >>> + >>> # 调整多个参数 + >>> dashscope.configure_http_connection_pool( + ... pool_connections=50, + ... pool_maxsize=100 + ... ) + """ + SessionManager.get_instance().configure( + pool_connections=pool_connections, + pool_maxsize=pool_maxsize, + max_retries=max_retries, + pool_block=pool_block, + ) + + +async def enable_aio_http_connection_pool( + limit: int = None, + limit_per_host: int = None, + ttl_dns_cache: int = None, + keepalive_timeout: int = None, + force_close: bool = None, +): + """ + 启用异步 HTTP 连接池复用 + + 启用后,所有异步 HTTP 请求将复用连接,显著减少延迟。 + + Args: + limit: 总连接数限制,默认 100 + - 低并发(< 10 req/s): 100 + - 中并发(10-50 req/s): 200 + - 高并发(> 50 req/s): 300-500 + + limit_per_host: 每个主机的连接数限制,默认 30 + - 应该 <= limit + - 低并发: 30 + - 中并发: 50 + - 高并发: 100 + + ttl_dns_cache: DNS 缓存 TTL(秒),默认 300 + - DNS 稳定: 300-600 + - DNS 变化频繁: 60-120 + + keepalive_timeout: Keep-Alive 超时(秒),默认 30 + - 短连接: 15-30 + - 长连接: 60-120 + + force_close: 是否强制关闭连接,默认 False + - False: 复用连接(推荐) + - True: 每次关闭连接 + + Examples: + >>> import asyncio + >>> import dashscope + >>> from dashscope import AioGeneration + >>> + >>> async def main(): + ... # 使用默认配置 + ... await dashscope.enable_aio_http_connection_pool() + ... + ... # 之后的所有异步请求都会复用连接 + ... response = await AioGeneration.call( + ... model='qwen-turbo', + ... prompt='Hello' + ... ) + ... + ... # 自定义配置 + ... await dashscope.enable_aio_http_connection_pool( + ... limit=200, + ... limit_per_host=50 + ... ) + >>> + >>> asyncio.run(main()) + """ + manager = await AioSessionManager.get_instance() + await manager.enable( + limit=limit, + limit_per_host=limit_per_host, + ttl_dns_cache=ttl_dns_cache, + keepalive_timeout=keepalive_timeout, + force_close=force_close, + ) + + +async def disable_aio_http_connection_pool(): + """ + 禁用异步 HTTP 连接池复用 + + 恢复到原有的每次请求创建新连接的行为。 + + Examples: + >>> import asyncio + >>> import dashscope + >>> + >>> async def main(): + ... await dashscope.disable_aio_http_connection_pool() + >>> + >>> asyncio.run(main()) + """ + manager = await AioSessionManager.get_instance() + await manager.disable() + + +async def reset_aio_http_connection_pool(): + """ + 重置异步 HTTP 连接池 + + 用于处理连接问题或网络切换场景。 + + Examples: + >>> import asyncio + >>> import dashscope + >>> + >>> async def main(): + ... await dashscope.reset_aio_http_connection_pool() + >>> + >>> asyncio.run(main()) + """ + manager = await AioSessionManager.get_instance() + await manager.reset() + + +async def configure_aio_http_connection_pool( + limit: int = None, + limit_per_host: int = None, + ttl_dns_cache: int = None, + keepalive_timeout: int = None, + force_close: bool = None, +): + """ + 配置异步 HTTP 连接池参数 + + 运行时动态调整连接池配置。 + + Args: + limit: 总连接数限制 + limit_per_host: 每个主机的连接数限制 + ttl_dns_cache: DNS 缓存 TTL(秒) + keepalive_timeout: Keep-Alive 超时(秒) + force_close: 是否强制关闭连接 + + Examples: + >>> import asyncio + >>> import dashscope + >>> + >>> async def main(): + ... # 调整单个参数 + ... await dashscope.configure_aio_http_connection_pool(limit=200) + ... + ... # 调整多个参数 + ... await dashscope.configure_aio_http_connection_pool( + ... limit=200, + ... limit_per_host=50 + ... ) + >>> + >>> asyncio.run(main()) + """ + manager = await AioSessionManager.get_instance() + await manager.configure( + limit=limit, + limit_per_host=limit_per_host, + ttl_dns_cache=ttl_dns_cache, + keepalive_timeout=keepalive_timeout, + force_close=force_close, + ) + + __all__ = [ "base_http_api_url", "base_websocket_api_url", @@ -118,6 +390,14 @@ "MessageFile", "AssistantFile", "VideoSynthesis", + "enable_http_connection_pool", + "disable_http_connection_pool", + "reset_http_connection_pool", + "configure_http_connection_pool", + "enable_aio_http_connection_pool", + "disable_aio_http_connection_pool", + "reset_aio_http_connection_pool", + "configure_aio_http_connection_pool", ] logging.getLogger(__name__).addHandler(NullHandler()) diff --git a/dashscope/api_entities/api_request_factory.py b/dashscope/api_entities/api_request_factory.py index bc96fc0..591966f 100644 --- a/dashscope/api_entities/api_request_factory.py +++ b/dashscope/api_entities/api_request_factory.py @@ -36,6 +36,8 @@ def _get_protocol_params(kwargs): base_address = kwargs.pop("base_address", None) flattened_output = kwargs.pop("flattened_output", False) extra_url_parameters = kwargs.pop("extra_url_parameters", None) + session = kwargs.pop("session", None) + aio_session = kwargs.pop("aio_session", None) # Extract user-agent from headers if present user_agent = "" @@ -58,6 +60,8 @@ def _get_protocol_params(kwargs): flattened_output, extra_url_parameters, user_agent, + session, + aio_session, ) @@ -87,6 +91,8 @@ def _build_api_request( # pylint: disable=too-many-branches flattened_output, extra_url_parameters, user_agent, + session, + aio_session, ) = _get_protocol_params(kwargs) task_id = kwargs.pop("task_id", None) enable_encryption = kwargs.pop("enable_encryption", False) @@ -130,6 +136,8 @@ def _build_api_request( # pylint: disable=too-many-branches flattened_output=flattened_output, encryption=encryption, user_agent=user_agent, + session=session, + aio_session=aio_session, ) elif api_protocol == ApiProtocol.WEBSOCKET: if base_address is not None: diff --git a/dashscope/api_entities/http_request.py b/dashscope/api_entities/http_request.py index d84dbc3..42343a7 100644 --- a/dashscope/api_entities/http_request.py +++ b/dashscope/api_entities/http_request.py @@ -42,6 +42,8 @@ def __init__( flattened_output: bool = False, encryption: Optional[Encryption] = None, user_agent: str = "", + session: Optional[requests.Session] = None, + aio_session: Optional[aiohttp.ClientSession] = None, ) -> None: """HttpSSERequest, processing http server sent event stream. @@ -54,6 +56,10 @@ def __init__( Defaults to DEFAULT_REQUEST_TIMEOUT_SECONDS. user_agent (str, optional): Additional user agent string to append. Defaults to ''. + session (Optional[requests.Session]): External session for + connection reuse (sync). Defaults to None. + aio_session (Optional[aiohttp.ClientSession]): External session + for connection reuse (async). Defaults to None. """ super().__init__(user_agent=user_agent) @@ -61,10 +67,13 @@ def __init__( self.flattened_output = flattened_output self.async_request = async_request self.encryption = encryption + self._external_session = session + self._external_aio_session = aio_session + base_headers = getattr(self, "headers", {}) self.headers = { "Accept": "application/json", "Authorization": f"Bearer {api_key}", - **self.headers, + **base_headers, } if encryption and encryption.is_valid(): @@ -102,6 +111,24 @@ def __init__( else: self.timeout = timeout # type: ignore[has-type] + def get_external_session(self) -> Optional[requests.Session]: + """ + 获取外部传入的同步 Session + + Returns: + Optional[requests.Session]: 外部 Session,如果未设置则返回 None + """ + return self._external_session + + def get_external_aio_session(self) -> Optional[aiohttp.ClientSession]: + """ + 获取外部传入的异步 Session + + Returns: + Optional[aiohttp.ClientSession]: 外部异步 Session,如果未设置则返回 None + """ + return self._external_aio_session + def add_header(self, key, value): self.headers[key] = value @@ -132,57 +159,119 @@ async def aio_call(self): pass return result + async def _get_aio_session(self): + """获取异步 Session(优先级:外部 > 全局 > 临时)""" + # 1. 检查是否有外部传入的 Session(最高优先级) + if self._external_aio_session is not None: + logger.debug( + "Using external async session for request: %s", + self.url, + ) + return self._external_aio_session, False + + # 2. 尝试获取全局异步 Session + from dashscope.common.aio_session_manager import AioSessionManager + + manager = await AioSessionManager.get_instance() + global_session = await manager.get_session() + + if global_session is not None: + logger.debug( + "Using global async session for request: %s", + self.url, + ) + return global_session, False + + # 3. 创建临时 Session(保持向后兼容) + connector = aiohttp.TCPConnector( + ssl=ssl.create_default_context( + cafile=certifi.where(), + ), + ) + session = aiohttp.ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers=self.headers, + ) + logger.debug( + "Using temporary async session for request: %s", + self.url, + ) + return session, True + + async def _execute_aio_request(self, session, timeout): + """执行异步 HTTP 请求""" + logger.debug("Starting request: %s", self.url) + + if self.method == HTTPMethod.POST: + return await self._execute_post_request(session, timeout) + if self.method == HTTPMethod.GET: + return await self._execute_get_request(session, timeout) + + raise UnsupportedHTTPMethod( + f"Unsupported http method: {self.method}", + ) + + async def _execute_post_request(self, session, timeout): + """执行 POST 请求""" + is_form, obj = False, {} + if hasattr(self, "data") and self.data is not None: + is_form, obj = self.data.get_aiohttp_payload() + + if is_form: + headers = {**self.headers, **obj.headers} + return await session.post( + url=self.url, + data=obj, + headers=headers, + timeout=timeout, + ) + + return await session.request( + "POST", + url=self.url, + json=obj, + headers=self.headers, + timeout=timeout, + ) + + async def _execute_get_request(self, session, timeout): + """执行 GET 请求""" + params = {} + if hasattr(self, "data") and self.data is not None: + params = getattr(self.data, "parameters", {}) + if params: + params = self.__handle_parameters(params) + + return await session.get( + url=self.url, + params=params, + headers=self.headers, + timeout=timeout, + ) + async def _handle_aio_request(self): try: - connector = aiohttp.TCPConnector( - ssl=ssl.create_default_context( - cafile=certifi.where(), - ), - ) - async with aiohttp.ClientSession( - connector=connector, - timeout=aiohttp.ClientTimeout(total=self.timeout), - headers=self.headers, - ) as session: - logger.debug("Starting request: %s", self.url) - if self.method == HTTPMethod.POST: - is_form, obj = False, {} - if hasattr(self, "data") and self.data is not None: - is_form, obj = self.data.get_aiohttp_payload() - if is_form: - headers = {**self.headers, **obj.headers} - response = await session.post( - url=self.url, - data=obj, - headers=headers, - ) - else: - response = await session.request( - "POST", - url=self.url, - json=obj, - headers=self.headers, - ) - elif self.method == HTTPMethod.GET: - # 添加条件判断 - params = {} - if hasattr(self, "data") and self.data is not None: - params = getattr(self.data, "parameters", {}) - if params: - params = self.__handle_parameters(params) - response = await session.get( - url=self.url, - params=params, - headers=self.headers, - ) - else: - raise UnsupportedHTTPMethod( - f"Unsupported http method: {self.method}", - ) + # 获取 Session(优先级:外部 > 全局 > 临时) + session, should_close = await self._get_aio_session() + + try: + # 设置超时 + timeout = aiohttp.ClientTimeout(total=self.timeout) + + # 执行请求 + response = await self._execute_aio_request(session, timeout) + logger.debug("Response returned: %s", self.url) async with response: async for rsp in self._handle_aio_response(response): yield rsp + finally: + # 只关闭临时 Session + if should_close: + await session.close() + logger.debug("Temporary async session closed") + except aiohttp.ClientConnectorError as e: logger.error(e) raise e @@ -407,8 +496,34 @@ def _handle_response( # pylint: disable=too-many-branches yield _handle_http_failed_response(response) def _handle_request(self): + """ + 处理 HTTP 请求 + + 优先级: + 1. 外部传入的 session(用户自定义) + 2. 全局 SessionManager(如果启用) + 3. 临时 session(保持原有行为) + """ try: - with requests.Session() as session: + from dashscope.common.session_manager import SessionManager + + # 优先使用外部传入的 session + if self._external_session is not None: + session = self._external_session + should_close = False + else: + # 尝试使用全局 SessionManager + session_manager = SessionManager.get_instance() + session = session_manager.get_session() + should_close = False + + # 如果未启用连接复用,创建临时 session + if session is None: + session = requests.Session() + should_close = True + + try: + # 执行请求 if self.method == HTTPMethod.POST: is_form, form, obj = self.data.get_http_payload() if is_form: @@ -441,8 +556,14 @@ def _handle_request(self): raise UnsupportedHTTPMethod( f"Unsupported http method: {self.method}", ) + for rsp in self._handle_response(response): yield rsp + finally: + # 只关闭临时创建的 session + if should_close: + session.close() + except BaseException as e: logger.error(e) raise e diff --git a/dashscope/common/aio_session_manager.py b/dashscope/common/aio_session_manager.py new file mode 100644 index 0000000..ceaa7a2 --- /dev/null +++ b/dashscope/common/aio_session_manager.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. + +"""异步 HTTP Session 管理器,用于管理 aiohttp.ClientSession 的连接池复用""" + +import asyncio +import ssl +from typing import Optional + +import aiohttp +import certifi + +from dashscope.common.logging import logger + + +class AioConnectionPoolConfig: + """异步连接池配置类""" + + def __init__( + self, + limit: int = 100, + limit_per_host: int = 30, + ttl_dns_cache: int = 300, + keepalive_timeout: int = 30, + force_close: bool = False, + ): + """ + 初始化异步连接池配置 + + Args: + limit: 总连接数限制,默认 100 + limit_per_host: 每个主机的连接数限制,默认 30 + ttl_dns_cache: DNS 缓存 TTL(秒),默认 300 + keepalive_timeout: Keep-Alive 超时(秒),默认 30 + force_close: 是否强制关闭连接,默认 False + + Raises: + ValueError: 当参数值不合法时 + """ + if limit <= 0: + raise ValueError(f"limit ({limit}) 必须 > 0") + if limit_per_host <= 0: + raise ValueError(f"limit_per_host ({limit_per_host}) 必须 > 0") + if limit_per_host > limit: + raise ValueError( + f"limit_per_host ({limit_per_host}) 必须 <= " f"limit ({limit})", + ) + if ttl_dns_cache < 0: + raise ValueError(f"ttl_dns_cache ({ttl_dns_cache}) 必须 >= 0") + if keepalive_timeout < 0: + raise ValueError( + f"keepalive_timeout ({keepalive_timeout}) 必须 >= 0", + ) + + self.limit = limit + self.limit_per_host = limit_per_host + self.ttl_dns_cache = ttl_dns_cache + self.keepalive_timeout = keepalive_timeout + self.force_close = force_close + + def __repr__(self): + return ( + f"AioConnectionPoolConfig(limit={self.limit}, " + f"limit_per_host={self.limit_per_host}, " + f"ttl_dns_cache={self.ttl_dns_cache}, " + f"keepalive_timeout={self.keepalive_timeout}, " + f"force_close={self.force_close})" + ) + + +class AioSessionManager: + """ + 异步 HTTP Session 管理器(单例模式) + + 用于管理全局的 aiohttp.ClientSession 实例,实现异步 HTTP 连接复用。 + + 特性: + - 单例模式:全局唯一实例 + - 异步锁保护:使用 asyncio.Lock 保护并发访问 + - 连接池配置:支持自定义 TCPConnector 参数 + - 生命周期管理:支持启用、禁用、重置 + - 向后兼容:默认禁用,不影响现有代码 + + Examples: + >>> import asyncio + >>> from dashscope.common.aio_session_manager import AioSessionManager + >>> + >>> async def main(): + ... manager = await AioSessionManager.get_instance() + ... await manager.enable(limit=200, limit_per_host=50) + ... session = await manager.get_session() + ... # 使用 session 进行请求 + ... await manager.disable() + >>> + >>> asyncio.run(main()) + """ + + _instance: Optional["AioSessionManager"] = None + _lock = asyncio.Lock() + + def __init__(self): + """初始化 Session 管理器(私有,通过 get_instance 获取)""" + self._enabled = False + self._session: Optional[aiohttp.ClientSession] = None + self._session_lock = asyncio.Lock() + self._config = AioConnectionPoolConfig() + logger.debug("AioSessionManager initialized") + + @classmethod + async def get_instance(cls) -> "AioSessionManager": + """ + 获取单例实例(异步) + + Returns: + AioSessionManager: 单例实例 + """ + if cls._instance is None: + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + logger.debug( + "AioSessionManager singleton instance created", + ) + return cls._instance + + @classmethod + async def reset_instance(cls): + """ + 重置单例实例(仅用于测试) + + 警告:此方法仅应在测试环境中使用 + """ + async with cls._lock: + if cls._instance is not None: + await cls._instance.disable() + await cls._instance.reset() + cls._instance = None + logger.debug("AioSessionManager singleton instance reset") + + async def enable( + self, + limit: int = None, + limit_per_host: int = None, + ttl_dns_cache: int = None, + keepalive_timeout: int = None, + force_close: bool = None, + ): + """ + 启用异步连接池复用 + + Args: + limit: 总连接数限制,默认 100 + limit_per_host: 每个主机的连接数限制,默认 30 + ttl_dns_cache: DNS 缓存 TTL(秒),默认 300 + keepalive_timeout: Keep-Alive 超时(秒),默认 30 + force_close: 是否强制关闭连接,默认 False + + Examples: + >>> manager = await AioSessionManager.get_instance() + >>> await manager.enable(limit=200, limit_per_host=50) + """ + async with self._session_lock: + # 如果提供了配置参数,先配置 + if any( + param is not None + for param in [ + limit, + limit_per_host, + ttl_dns_cache, + keepalive_timeout, + force_close, + ] + ): + await self._configure( + limit=limit, + limit_per_host=limit_per_host, + ttl_dns_cache=ttl_dns_cache, + keepalive_timeout=keepalive_timeout, + force_close=force_close, + ) + + self._enabled = True + await self._ensure_session() + logger.info( + "Async HTTP connection pool enabled with config: %s", + self._config, + ) + + async def disable(self): + """ + 禁用异步连接池复用 + + 关闭当前 Session 并禁用连接池功能 + + Examples: + >>> manager = await AioSessionManager.get_instance() + >>> await manager.disable() + """ + async with self._session_lock: + self._enabled = False + if self._session and not self._session.closed: + await self._session.close() + logger.debug("Async ClientSession closed") + self._session = None + logger.info("Async HTTP connection pool disabled") + + async def configure( + self, + limit: int = None, + limit_per_host: int = None, + ttl_dns_cache: int = None, + keepalive_timeout: int = None, + force_close: bool = None, + ): + """ + 配置连接池参数 + + Args: + limit: 总连接数限制 + limit_per_host: 每个主机的连接数限制 + ttl_dns_cache: DNS 缓存 TTL(秒) + keepalive_timeout: Keep-Alive 超时(秒) + force_close: 是否强制关闭连接 + + Examples: + >>> manager = await AioSessionManager.get_instance() + >>> await manager.configure(limit=200, limit_per_host=50) + """ + async with self._session_lock: + await self._configure( + limit=limit, + limit_per_host=limit_per_host, + ttl_dns_cache=ttl_dns_cache, + keepalive_timeout=keepalive_timeout, + force_close=force_close, + ) + + async def _configure( + self, + limit: int = None, + limit_per_host: int = None, + ttl_dns_cache: int = None, + keepalive_timeout: int = None, + force_close: bool = None, + ): + """内部配置方法(无锁)""" + config_params = {} + if limit is not None: + config_params["limit"] = limit + if limit_per_host is not None: + config_params["limit_per_host"] = limit_per_host + if ttl_dns_cache is not None: + config_params["ttl_dns_cache"] = ttl_dns_cache + if keepalive_timeout is not None: + config_params["keepalive_timeout"] = keepalive_timeout + if force_close is not None: + config_params["force_close"] = force_close + + if config_params: + # 创建新配置 + limit = config_params.get("limit", self._config.limit) + limit_per_host = config_params.get( + "limit_per_host", + self._config.limit_per_host, + ) + ttl_dns_cache = config_params.get( + "ttl_dns_cache", + self._config.ttl_dns_cache, + ) + keepalive_timeout = config_params.get( + "keepalive_timeout", + self._config.keepalive_timeout, + ) + force_close = config_params.get( + "force_close", + self._config.force_close, + ) + + new_config = AioConnectionPoolConfig( + limit=limit, + limit_per_host=limit_per_host, + ttl_dns_cache=ttl_dns_cache, + keepalive_timeout=keepalive_timeout, + force_close=bool(force_close), + ) + self._config = new_config + + # 如果已启用,重新创建 Session + if self._enabled: + if self._session and not self._session.closed: + await self._session.close() + self._session = None + await self._ensure_session() + logger.info( + "Async connection pool reconfigured: %s", + self._config, + ) + + async def _ensure_session(self): + """确保 Session 存在且有效(内部方法,无锁)""" + if self._session is None or self._session.closed: + connector = aiohttp.TCPConnector( + limit=self._config.limit, + limit_per_host=self._config.limit_per_host, + ttl_dns_cache=self._config.ttl_dns_cache, + keepalive_timeout=self._config.keepalive_timeout, + force_close=self._config.force_close, + ssl=ssl.create_default_context(cafile=certifi.where()), + ) + self._session = aiohttp.ClientSession(connector=connector) + logger.debug( + "New async ClientSession created with config: %s", + self._config, + ) + + async def get_session(self) -> Optional[aiohttp.ClientSession]: + """ + 获取 Session(如果启用) + + Returns: + Optional[aiohttp.ClientSession]: 如果启用返回 Session,否则返回 None + + Examples: + >>> manager = await AioSessionManager.get_instance() + >>> await manager.enable() + >>> session = await manager.get_session() + >>> if session: + ... # 使用 session 进行请求 + ... pass + """ + async with self._session_lock: + if self._enabled: + await self._ensure_session() + return self._session + return None + + async def get_session_direct(self) -> Optional[aiohttp.ClientSession]: + """ + 直接获取 Session(不检查启用状态) + + Returns: + Optional[aiohttp.ClientSession]: 当前 Session 或 None + + Note: + 此方法主要用于测试,一般应使用 get_session() + """ + async with self._session_lock: + return self._session + + async def reset(self): + """ + 重置 Session + + 关闭当前 Session 并根据启用状态重新创建 + + Examples: + >>> manager = await AioSessionManager.get_instance() + >>> await manager.reset() + """ + async with self._session_lock: + if self._session and not self._session.closed: + await self._session.close() + logger.debug("Async ClientSession closed during reset") + self._session = None + if self._enabled: + await self._ensure_session() + logger.info("Async HTTP connection pool reset") + + def get_config(self) -> AioConnectionPoolConfig: + """ + 获取当前连接池配置 + + Returns: + AioConnectionPoolConfig: 当前配置 + + Examples: + >>> manager = await AioSessionManager.get_instance() + >>> config = manager.get_config() + >>> print(config.limit) + """ + return self._config + + def is_enabled(self) -> bool: + """ + 检查连接池是否已启用 + + Returns: + bool: 是否已启用 + + Examples: + >>> manager = await AioSessionManager.get_instance() + >>> if manager.is_enabled(): + ... print("Connection pool is enabled") + """ + return self._enabled + + async def has_active_session(self) -> bool: + """ + 检查是否有活跃的 Session + + Returns: + bool: 是否有活跃的 Session + + Examples: + >>> manager = await AioSessionManager.get_instance() + >>> if await manager.has_active_session(): + ... print("Active session exists") + """ + async with self._session_lock: + return self._session is not None and not self._session.closed diff --git a/dashscope/common/session_manager.py b/dashscope/common/session_manager.py new file mode 100644 index 0000000..cdf321e --- /dev/null +++ b/dashscope/common/session_manager.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. + +import threading +from typing import Optional + +import requests +from requests.adapters import HTTPAdapter + +from dashscope.common.logging import logger + + +class ConnectionPoolConfig: + """ + 连接池配置类 + + 提供类型安全和参数验证的配置方式 + """ + + def __init__( + self, + pool_connections: int = 10, + pool_maxsize: int = 20, + max_retries: int = 3, + pool_block: bool = False, + ): + """ + 初始化连接池配置 + + Args: + pool_connections: 连接池大小,默认 10 + - 低并发(< 10 req/s): 10 + - 中并发(10-50 req/s): 20-30 + - 高并发(> 50 req/s): 50-100 + + pool_maxsize: 最大连接数,默认 20 + - 应该 >= pool_connections + - 低并发: 20 + - 中并发: 50 + - 高并发: 100-200 + + max_retries: 重试次数,默认 3 + - 网络稳定: 3 + - 网络不稳定: 5-10 + + pool_block: 连接池满时是否阻塞,默认 False + - False: 连接池满时创建新连接(推荐) + - True: 连接池满时等待可用连接 + """ + # 参数验证 + if pool_connections < 1: + raise ValueError("pool_connections 必须 >= 1") + if pool_maxsize < pool_connections: + raise ValueError("pool_maxsize 必须 >= pool_connections") + if max_retries < 0: + raise ValueError("max_retries 必须 >= 0") + + self.pool_connections = pool_connections + self.pool_maxsize = pool_maxsize + self.max_retries = max_retries + self.pool_block = pool_block + + def to_dict(self): + """转换为字典格式""" + return { + "pool_connections": self.pool_connections, + "pool_maxsize": self.pool_maxsize, + "max_retries": self.max_retries, + "pool_block": self.pool_block, + } + + def __repr__(self): + return ( + f"ConnectionPoolConfig(" + f"pool_connections={self.pool_connections}, " + f"pool_maxsize={self.pool_maxsize}, " + f"max_retries={self.max_retries}, " + f"pool_block={self.pool_block})" + ) + + +class SessionManager: + """ + 全局 HTTP Session 管理器 + + 特性: + 1. 线程安全的 Session 池 + 2. 支持全局启用/禁用连接复用 + 3. 支持自定义 Session 配置 + 4. 自动清理和重建机制 + """ + + _instance = None + _lock = threading.Lock() + + def __init__(self): + self._enabled = False # 默认关闭,保持向后兼容 + self._session = None + self._session_lock = threading.RLock() + self._config = ConnectionPoolConfig() # 使用配置类 + + @classmethod + def get_instance(cls): + """单例模式获取实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def reset_instance(cls): + """ + 重置单例实例(仅用于测试) + + 警告:此方法仅应在测试环境中使用 + """ + with cls._lock: + if cls._instance is not None: + cls._instance.disable() + cls._instance.reset() + cls._instance = None + + def enable( + self, + pool_connections: Optional[int] = None, + pool_maxsize: Optional[int] = None, + max_retries: Optional[int] = None, + pool_block: Optional[bool] = None, + ): + """ + 启用连接复用 + + Args: + pool_connections: 连接池大小,默认 10 + pool_maxsize: 最大连接数,默认 20 + max_retries: 重试次数,默认 3 + pool_block: 连接池满时是否阻塞,默认 False + + Examples: + # 使用默认配置 + enable() + + # 使用命名参数 + enable(pool_connections=50, pool_maxsize=100) + """ + with self._session_lock: + # 使用命名参数更新配置 + if pool_connections is not None: + self._config.pool_connections = pool_connections + if pool_maxsize is not None: + self._config.pool_maxsize = pool_maxsize + if max_retries is not None: + self._config.max_retries = max_retries + if pool_block is not None: + self._config.pool_block = pool_block + + # 参数验证 + if self._config.pool_maxsize < self._config.pool_connections: + raise ValueError( + f"pool_maxsize ({self._config.pool_maxsize}) 必须 >= " + f"pool_connections ({self._config.pool_connections})", + ) + + self._enabled = True + self._ensure_session() + logger.info( + "HTTP connection pool enabled with config: %s", + self._config, + ) + + def disable(self): + """禁用连接复用,关闭现有 Session""" + with self._session_lock: + self._enabled = False + if self._session: + try: + self._session.close() + except Exception as e: + logger.warning("Error closing session: %s", e) + finally: + self._session = None + logger.info("HTTP connection pool disabled") + + def is_enabled(self): + """检查是否启用连接复用""" + return self._enabled + + def get_config(self) -> ConnectionPoolConfig: + """ + 获取当前连接池配置 + + Returns: + ConnectionPoolConfig: 当前配置对象 + """ + return self._config + + def has_active_session(self) -> bool: + """ + 检查是否有活跃的 Session + + Returns: + bool: 如果存在活跃的 Session 返回 True,否则返回 False + """ + with self._session_lock: + return self._session is not None + + def _ensure_session(self): + """确保 Session 存在且有效(需要持有锁)""" + if self._session is None: + self._session = requests.Session() + + # 配置连接池 + adapter = HTTPAdapter( + pool_connections=self._config.pool_connections, + pool_maxsize=self._config.pool_maxsize, + max_retries=self._config.max_retries, + pool_block=self._config.pool_block, + ) + + self._session.mount("http://", adapter) + self._session.mount("https://", adapter) + logger.debug("Created new HTTP session with connection pool") + + def get_session(self) -> Optional[requests.Session]: + """ + 获取 Session 对象 + + Returns: + 如果启用了连接复用,返回全局 Session + 否则返回 None + + Examples: + >>> manager = SessionManager.get_instance() + >>> manager.enable() + >>> session = manager.get_session() + >>> if session: + ... response = session.get(url) + """ + if not self._enabled: + return None + + with self._session_lock: + self._ensure_session() + return self._session + + def reset(self): + """重置 Session(用于处理连接问题)""" + with self._session_lock: + if self._session: + try: + self._session.close() + except Exception as e: + logger.warning("Error closing session during reset: %s", e) + finally: + self._session = None + if self._enabled: + self._ensure_session() + logger.info("HTTP connection pool reset") + + def configure( + self, + pool_connections: Optional[int] = None, + pool_maxsize: Optional[int] = None, + max_retries: Optional[int] = None, + pool_block: Optional[bool] = None, + ): + """ + 更新配置并重建 Session + + Args: + pool_connections: 连接池大小 + pool_maxsize: 最大连接数 + max_retries: 重试次数 + pool_block: 连接池满时是否阻塞 + + Examples: + # 调整单个参数 + configure(pool_maxsize=100) + + # 调整多个参数 + configure(pool_connections=50, pool_maxsize=100) + """ + with self._session_lock: + # 使用命名参数更新配置 + if pool_connections is not None: + self._config.pool_connections = pool_connections + if pool_maxsize is not None: + self._config.pool_maxsize = pool_maxsize + if max_retries is not None: + self._config.max_retries = max_retries + if pool_block is not None: + self._config.pool_block = pool_block + + # 参数验证 + if self._config.pool_maxsize < self._config.pool_connections: + raise ValueError( + f"pool_maxsize ({self._config.pool_maxsize}) 必须 >= " + f"pool_connections ({self._config.pool_connections})", + ) + + if self._enabled: + # 重建 Session 以应用新配置 + if self._session: + try: + self._session.close() + except Exception as e: + logger.warning( + "Error closing session during configure: %s", + e, + ) + finally: + self._session = None + self._ensure_session() + logger.info("HTTP connection pool configured: %s", self._config) diff --git a/samples/test_generation.py b/samples/test_generation.py index 71cde15..d7f42f4 100644 --- a/samples/test_generation.py +++ b/samples/test_generation.py @@ -2,6 +2,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os + +import dashscope from dashscope import Generation @@ -10,6 +12,9 @@ class TestGeneration: @staticmethod def test_response_with_content(): + + dashscope.enable_http_connection_pool() + messages = [ {"role": "system", "content": "You are a helpful assistant."}, { diff --git a/tests/unit/test_aio_connection_pool.py b/tests/unit/test_aio_connection_pool.py new file mode 100644 index 0000000..dd8ae51 --- /dev/null +++ b/tests/unit/test_aio_connection_pool.py @@ -0,0 +1,510 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. + +"""异步连接池单元测试""" + +import asyncio + +import aiohttp +import pytest + +from dashscope.common.aio_session_manager import ( + AioConnectionPoolConfig, + AioSessionManager, +) + + +class TestAioConnectionPoolConfig: + """测试 AioConnectionPoolConfig 类""" + + def test_default_config(self): + """测试默认配置""" + config = AioConnectionPoolConfig() + assert config.limit == 100 + assert config.limit_per_host == 30 + assert config.ttl_dns_cache == 300 + assert config.keepalive_timeout == 30 + assert config.force_close is False + + def test_custom_config(self): + """测试自定义配置""" + config = AioConnectionPoolConfig( + limit=200, + limit_per_host=50, + ttl_dns_cache=600, + keepalive_timeout=60, + force_close=True, + ) + assert config.limit == 200 + assert config.limit_per_host == 50 + assert config.ttl_dns_cache == 600 + assert config.keepalive_timeout == 60 + assert config.force_close is True + + def test_config_validation(self): + """测试配置参数验证""" + # limit 必须 > 0 + with pytest.raises(ValueError, match=r"limit.*必须 > 0"): + AioConnectionPoolConfig(limit=0) + + # limit_per_host 必须 > 0 + with pytest.raises(ValueError, match=r"limit_per_host.*必须 > 0"): + AioConnectionPoolConfig(limit_per_host=0) + + # limit_per_host 必须 <= limit + with pytest.raises(ValueError, match=r"limit_per_host.*必须 <="): + AioConnectionPoolConfig(limit=50, limit_per_host=100) + + # ttl_dns_cache 必须 >= 0 + with pytest.raises(ValueError, match=r"ttl_dns_cache.*必须 >= 0"): + AioConnectionPoolConfig(ttl_dns_cache=-1) + + # keepalive_timeout 必须 >= 0 + with pytest.raises(ValueError, match=r"keepalive_timeout.*必须 >= 0"): + AioConnectionPoolConfig(keepalive_timeout=-1) + + def test_config_repr(self): + """测试配置的字符串表示""" + config = AioConnectionPoolConfig(limit=200, limit_per_host=50) + repr_str = repr(config) + assert "AioConnectionPoolConfig" in repr_str + assert "limit=200" in repr_str + assert "limit_per_host=50" in repr_str + + +class TestAioSessionManager: + """测试 AioSessionManager 类""" + + @pytest.fixture(autouse=True) + async def cleanup(self): + """每个测试后清理单例实例""" + yield + await AioSessionManager.reset_instance() + + @pytest.mark.asyncio + async def test_singleton_pattern(self): + """测试单例模式""" + manager1 = await AioSessionManager.get_instance() + manager2 = await AioSessionManager.get_instance() + assert manager1 is manager2 + + @pytest.mark.asyncio + async def test_default_state(self): + """测试默认状态""" + manager = await AioSessionManager.get_instance() + assert not manager.is_enabled() + assert not await manager.has_active_session() + config = manager.get_config() + assert config.limit == 100 + assert config.limit_per_host == 30 + + @pytest.mark.asyncio + async def test_enable(self): + """测试启用连接池""" + manager = await AioSessionManager.get_instance() + await manager.enable() + assert manager.is_enabled() + assert await manager.has_active_session() + + @pytest.mark.asyncio + async def test_enable_with_config(self): + """测试启用时配置参数""" + manager = await AioSessionManager.get_instance() + await manager.enable(limit=200, limit_per_host=50) + config = manager.get_config() + assert config.limit == 200 + assert config.limit_per_host == 50 + + @pytest.mark.asyncio + async def test_disable(self): + """测试禁用连接池""" + manager = await AioSessionManager.get_instance() + await manager.enable() + assert manager.is_enabled() + + await manager.disable() + assert not manager.is_enabled() + assert not await manager.has_active_session() + + @pytest.mark.asyncio + async def test_get_session(self): + """测试获取 Session""" + manager = await AioSessionManager.get_instance() + + # 禁用时返回 None + session = await manager.get_session() + assert session is None + + # 启用后返回 Session + await manager.enable() + session = await manager.get_session() + assert session is not None + assert isinstance(session, aiohttp.ClientSession) + assert not session.closed + + @pytest.mark.asyncio + async def test_get_session_direct(self): + """测试直接获取 Session""" + manager = await AioSessionManager.get_instance() + + # 禁用时返回 None + session = await manager.get_session_direct() + assert session is None + + # 启用后返回 Session + await manager.enable() + session = await manager.get_session_direct() + assert session is not None + + # 禁用后 Session 被关闭 + await manager.disable() + session = await manager.get_session_direct() + assert session is None + + @pytest.mark.asyncio + async def test_configure(self): + """测试配置连接池""" + manager = await AioSessionManager.get_instance() + await manager.enable() + + # 配置参数 + await manager.configure(limit=200, limit_per_host=50) + config = manager.get_config() + assert config.limit == 200 + assert config.limit_per_host == 50 + + @pytest.mark.asyncio + async def test_configure_before_enable(self): + """测试启用前配置""" + manager = await AioSessionManager.get_instance() + + # 启用前配置不会创建 Session + await manager.configure(limit=200) + assert not await manager.has_active_session() + + # 启用后使用配置的参数 + await manager.enable() + config = manager.get_config() + assert config.limit == 200 + + @pytest.mark.asyncio + async def test_reset(self): + """测试重置连接池""" + manager = await AioSessionManager.get_instance() + await manager.enable() + + old_session = await manager.get_session_direct() + assert old_session is not None + + # 重置后创建新 Session + await manager.reset() + new_session = await manager.get_session_direct() + assert new_session is not None + assert new_session is not old_session + + @pytest.mark.asyncio + async def test_reset_when_disabled(self): + """测试禁用状态下重置""" + manager = await AioSessionManager.get_instance() + await manager.enable() + await manager.disable() + + # 禁用状态下重置不会创建 Session + await manager.reset() + assert not await manager.has_active_session() + + @pytest.mark.asyncio + async def test_reset_instance(self): + """测试重置单例实例""" + manager1 = await AioSessionManager.get_instance() + await manager1.enable() + + await AioSessionManager.reset_instance() + + manager2 = await AioSessionManager.get_instance() + assert not manager2.is_enabled() + assert not await manager2.has_active_session() + + @pytest.mark.asyncio + async def test_session_reuse(self): + """测试 Session 复用""" + manager = await AioSessionManager.get_instance() + await manager.enable() + + session1 = await manager.get_session() + session2 = await manager.get_session() + assert session1 is session2 + + @pytest.mark.asyncio + async def test_session_recreation_on_configure(self): + """测试配置变更时重新创建 Session""" + manager = await AioSessionManager.get_instance() + await manager.enable() + + old_session = await manager.get_session_direct() + + # 配置变更后 Session 被重新创建 + await manager.configure(limit=200) + new_session = await manager.get_session_direct() + assert new_session is not old_session + + @pytest.mark.asyncio + async def test_concurrent_enable(self): + """测试并发启用""" + manager = await AioSessionManager.get_instance() + + # 并发启用 + await asyncio.gather( + manager.enable(), + manager.enable(), + manager.enable(), + ) + + assert manager.is_enabled() + assert await manager.has_active_session() + + @pytest.mark.asyncio + async def test_concurrent_get_session(self): + """测试并发获取 Session""" + manager = await AioSessionManager.get_instance() + await manager.enable() + + # 并发获取 Session + sessions = await asyncio.gather( + manager.get_session(), + manager.get_session(), + manager.get_session(), + ) + + # 所有 Session 应该是同一个实例 + assert all(s is sessions[0] for s in sessions) + + @pytest.mark.asyncio + async def test_concurrent_enable_disable(self): + """测试并发启用和禁用""" + manager = await AioSessionManager.get_instance() + + async def enable_disable(): + await manager.enable() + await asyncio.sleep(0.01) + await manager.disable() + + # 并发执行启用和禁用 + await asyncio.gather( + enable_disable(), + enable_disable(), + enable_disable(), + ) + + # 最终状态应该是禁用 + assert not manager.is_enabled() + + @pytest.mark.asyncio + async def test_session_closed_detection(self): + """测试 Session 关闭检测""" + manager = await AioSessionManager.get_instance() + await manager.enable() + + session = await manager.get_session_direct() + assert not session.closed + + # 手动关闭 Session + await session.close() + + # get_session 应该创建新的 Session + new_session = await manager.get_session() + assert new_session is not session + assert not new_session.closed + + +class TestAioConnectionPoolIntegration: + """测试异步连接池集成""" + + @pytest.fixture(autouse=True) + async def cleanup(self): + """每个测试后清理""" + yield + await AioSessionManager.reset_instance() + + @pytest.mark.asyncio + async def test_default_behavior_unchanged(self): + """测试默认行为不变""" + manager = await AioSessionManager.get_instance() + + # 默认禁用,不影响现有代码 + session = await manager.get_session() + assert session is None + + @pytest.mark.asyncio + async def test_enable_affects_all_requests(self): + """测试启用后影响所有请求""" + manager = await AioSessionManager.get_instance() + await manager.enable() + + # 所有请求应该使用同一个 Session + session1 = await manager.get_session() + session2 = await manager.get_session() + assert session1 is session2 + + @pytest.mark.asyncio + async def test_disable_stops_reuse(self): + """测试禁用后停止复用""" + manager = await AioSessionManager.get_instance() + await manager.enable() + + session_before = await manager.get_session() + assert session_before is not None + + await manager.disable() + + session_after = await manager.get_session() + assert session_after is None + + @pytest.mark.asyncio + async def test_multiple_enable_disable_cycles(self): + """测试多次启用/禁用循环""" + manager = await AioSessionManager.get_instance() + + for _ in range(3): + await manager.enable() + assert manager.is_enabled() + session = await manager.get_session() + assert session is not None + + await manager.disable() + assert not manager.is_enabled() + session = await manager.get_session() + assert session is None + + +class TestAioCustomSession: + """测试自定义异步 Session""" + + @pytest.fixture(autouse=True) + async def cleanup(self): + """每个测试后清理""" + yield + await AioSessionManager.reset_instance() + + @pytest.mark.asyncio + async def test_external_session_priority(self): + """测试外部 Session 优先级最高""" + from dashscope.api_entities.http_request import HttpRequest + + # 创建外部 Session + external_session = aiohttp.ClientSession() + + # 创建 HttpRequest(传入外部 Session) + http_request = HttpRequest( + url="https://example.com", + api_key="test_key", + http_method="POST", + aio_session=external_session, + ) + + # 验证外部 Session 被存储 + assert http_request.get_external_aio_session() is external_session + + await external_session.close() + + @pytest.mark.asyncio + async def test_external_session_overrides_global(self): + """测试外部 Session 覆盖全局连接池""" + from dashscope.api_entities.http_request import HttpRequest + + # 启用全局连接池 + manager = await AioSessionManager.get_instance() + await manager.enable() + + # 创建外部 Session + external_session = aiohttp.ClientSession() + + # 创建 HttpRequest(传入外部 Session) + http_request = HttpRequest( + url="https://example.com", + api_key="test_key", + http_method="POST", + aio_session=external_session, + ) + + # 验证使用外部 Session + assert http_request.get_external_aio_session() is external_session + + await external_session.close() + + +class TestAioConnectionPoolEdgeCases: + """测试异步连接池边界情况""" + + @pytest.fixture(autouse=True) + async def cleanup(self): + """每个测试后清理""" + yield + await AioSessionManager.reset_instance() + + @pytest.mark.asyncio + async def test_configure_partial_params(self): + """测试部分配置参数""" + manager = await AioSessionManager.get_instance() + await manager.enable() + + # 只配置部分参数 + await manager.configure(limit=200) + config = manager.get_config() + assert config.limit == 200 + assert config.limit_per_host == 30 # 保持默认值 + + @pytest.mark.asyncio + async def test_enable_multiple_times(self): + """测试多次启用""" + manager = await AioSessionManager.get_instance() + + await manager.enable() + session1 = await manager.get_session_direct() + + await manager.enable() + session2 = await manager.get_session_direct() + + # 多次启用不会重新创建 Session + assert session1 is session2 + + @pytest.mark.asyncio + async def test_disable_multiple_times(self): + """测试多次禁用""" + manager = await AioSessionManager.get_instance() + await manager.enable() + + await manager.disable() + await manager.disable() # 不应该报错 + + assert not manager.is_enabled() + + @pytest.mark.asyncio + async def test_reset_multiple_times(self): + """测试多次重置""" + manager = await AioSessionManager.get_instance() + await manager.enable() + + await manager.reset() + await manager.reset() # 不应该报错 + + assert manager.is_enabled() + assert await manager.has_active_session() + + @pytest.mark.asyncio + async def test_configure_with_no_params(self): + """测试无参数配置""" + manager = await AioSessionManager.get_instance() + await manager.enable() + + old_config = manager.get_config() + await manager.configure() # 不传参数 + new_config = manager.get_config() + + # 配置应该保持不变 + assert old_config.limit == new_config.limit + assert old_config.limit_per_host == new_config.limit_per_host + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_connection_pool.py b/tests/unit/test_connection_pool.py new file mode 100644 index 0000000..48ec833 --- /dev/null +++ b/tests/unit/test_connection_pool.py @@ -0,0 +1,675 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. + +""" +HTTP 连接池功能单元测试 + +测试范围: +1. SessionManager 基本功能 +2. ConnectionPoolConfig 配置类 +3. HttpRequest 与 Session 集成 +4. 全局连接池 API +5. 自定义 Session 支持 +6. 线程安全性 +""" + +import threading +import time + +import pytest +import requests +from requests.adapters import HTTPAdapter + +import dashscope +from dashscope.common.session_manager import ( + SessionManager, + ConnectionPoolConfig, +) +from dashscope.api_entities.http_request import HttpRequest +from tests.unit.base_test import BaseTestEnvironment + + +class TestConnectionPoolConfig: + """测试 ConnectionPoolConfig 配置类""" + + def test_default_config(self): + """测试默认配置""" + config = ConnectionPoolConfig() + assert config.pool_connections == 10 + assert config.pool_maxsize == 20 + assert config.max_retries == 3 + assert config.pool_block is False + + def test_custom_config(self): + """测试自定义配置""" + config = ConnectionPoolConfig( + pool_connections=20, + pool_maxsize=50, + max_retries=5, + pool_block=True, + ) + assert config.pool_connections == 20 + assert config.pool_maxsize == 50 + assert config.max_retries == 5 + assert config.pool_block is True + + def test_config_validation(self): + """测试配置验证""" + # 测试负数验证 + with pytest.raises(ValueError, match="pool_connections 必须"): + ConnectionPoolConfig(pool_connections=0) + + with pytest.raises(ValueError, match="pool_maxsize 必须"): + ConnectionPoolConfig(pool_maxsize=0) + + with pytest.raises(ValueError, match="max_retries 必须"): + ConnectionPoolConfig(max_retries=-1) + + # 测试 pool_maxsize >= pool_connections + with pytest.raises( + ValueError, + match="pool_maxsize.*必须.*pool_connections", + ): + ConnectionPoolConfig(pool_connections=30, pool_maxsize=20) + + def test_config_to_dict(self): + """测试配置转换为字典""" + config = ConnectionPoolConfig( + pool_connections=15, + pool_maxsize=30, + max_retries=5, + pool_block=True, + ) + config_dict = config.to_dict() + assert config_dict == { + "pool_connections": 15, + "pool_maxsize": 30, + "max_retries": 5, + "pool_block": True, + } + + def test_config_str(self): + """测试配置字符串表示""" + config = ConnectionPoolConfig() + config_str = str(config) + assert "pool_connections=10" in config_str + assert "pool_maxsize=20" in config_str + assert "max_retries=3" in config_str + assert "pool_block=False" in config_str + + +class TestSessionManager: + """测试 SessionManager 单例类""" + + def setup_method(self): + """每个测试前重置 SessionManager""" + SessionManager.reset_instance() + + def teardown_method(self): + """每个测试后清理""" + manager = SessionManager.get_instance() + manager.reset() + + def test_singleton_pattern(self): + """测试单例模式""" + manager1 = SessionManager.get_instance() + manager2 = SessionManager.get_instance() + assert manager1 is manager2 + + def test_singleton_thread_safe(self): + """测试单例模式的线程安全性""" + instances = [] + + def get_instance(): + instances.append(SessionManager.get_instance()) + + threads = [threading.Thread(target=get_instance) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + # 所有实例应该是同一个 + assert all(inst is instances[0] for inst in instances) + + def test_enable_disable(self): + """测试启用和禁用连接池""" + manager = SessionManager.get_instance() + + # 默认禁用 + assert not manager.is_enabled() + + # 启用 + manager.enable() + assert manager.is_enabled() + + # 禁用 + manager.disable() + assert not manager.is_enabled() + + def test_enable_with_config(self): + """测试使用配置启用连接池""" + manager = SessionManager.get_instance() + + manager.enable( + pool_connections=15, + pool_maxsize=30, + max_retries=5, + pool_block=True, + ) + + assert manager.is_enabled() + config = manager.get_config() + assert config.pool_connections == 15 + assert config.pool_maxsize == 30 + assert config.max_retries == 5 + assert config.pool_block is True + + def test_configure(self): + """测试配置连接池""" + manager = SessionManager.get_instance() + manager.enable() + + # 配置连接池 + manager.configure( + pool_connections=25, + pool_maxsize=50, + ) + + config = manager.get_config() + assert config.pool_connections == 25 + assert config.pool_maxsize == 50 + assert config.max_retries == 3 # 保持默认值 + + def test_get_session_when_disabled(self): + """测试禁用时获取 Session(直接方式)""" + manager = SessionManager.get_instance() + manager.disable() + + session = manager.get_session() + assert session is None + + def test_get_session_when_enabled(self): + """测试启用时获取 Session(直接方式)""" + manager = SessionManager.get_instance() + manager.enable() + + session = manager.get_session() + assert session is not None + assert isinstance(session, requests.Session) + + def test_get_session_returns_same_instance(self): + """测试获取 Session 返回同一实例""" + manager = SessionManager.get_instance() + manager.enable() + + session1 = manager.get_session() + session2 = manager.get_session() + assert session1 is session2 + + def test_get_session(self): + """测试直接获取 Session""" + manager = SessionManager.get_instance() + + # 启用时能获取 + manager.enable() + session = manager.get_session() + assert session is not None + assert isinstance(session, requests.Session) + + # 禁用时返回 None + manager.disable() + session = manager.get_session() + assert session is None + + def test_reset(self): + """测试重置连接池""" + manager = SessionManager.get_instance() + manager.enable() + + old_session = manager.get_session() + assert old_session is not None + + # 禁用后重置 + manager.disable() + manager.reset() + + # Session 应该被清理 + assert not manager.has_active_session() + assert not manager.is_enabled() + + # 重新启用后应该是新的 Session + manager.enable() + new_session = manager.get_session() + assert new_session is not old_session + + def test_session_has_adapter(self): + """测试 Session 配置了 HTTPAdapter""" + manager = SessionManager.get_instance() + manager.enable(pool_connections=15, pool_maxsize=30) + + session = manager.get_session() + assert session is not None + + # 检查是否配置了 HTTPAdapter + http_adapter = session.get_adapter("http://") + https_adapter = session.get_adapter("https://") + + assert isinstance(http_adapter, HTTPAdapter) + assert isinstance(https_adapter, HTTPAdapter) + + def test_thread_safe_session_creation(self): + """测试多线程环境下 Session 创建的线程安全性""" + manager = SessionManager.get_instance() + manager.enable() + + sessions = [] + + def get_session(): + sessions.append(manager.get_session()) + + threads = [threading.Thread(target=get_session) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join() + + # 所有线程应该获取到同一个 Session + assert all(s is sessions[0] for s in sessions) + + +class TestHttpRequestSessionIntegration: + """测试 HttpRequest 与 Session 的集成""" + + def test_http_request_accepts_session(self): + """测试 HttpRequest 接受 session 参数""" + custom_session = requests.Session() + + http_request = HttpRequest( + url="http://example.com/api", + api_key="test-key", + http_method="POST", + session=custom_session, + ) + + assert http_request.get_external_session() is custom_session + + def test_http_request_without_session(self): + """测试 HttpRequest 不传 session 参数""" + http_request = HttpRequest( + url="http://example.com/api", + api_key="test-key", + http_method="POST", + ) + + assert http_request.get_external_session() is None + + def test_http_request_uses_external_session_priority(self): + """测试 HttpRequest 优先使用外部传入的 Session""" + # 创建自定义 Session + custom_session = requests.Session() + custom_session.headers.update({"X-Test": "custom"}) + + # 创建 HttpRequest,传入自定义 Session + http_request = HttpRequest( + url="http://example.com/api", + api_key="test-key", + http_method="POST", + session=custom_session, + ) + + # 验证使用了自定义 Session + assert http_request.get_external_session() is custom_session + assert ( + http_request.get_external_session().headers.get("X-Test") + == "custom" + ) + + def test_http_request_session_priority(self): + """测试 Session 优先级:外部 > 全局 > 临时""" + # 1. 外部 Session 优先级最高 + custom_session = requests.Session() + http_request = HttpRequest( + url="http://example.com/api", + api_key="test-key", + http_method="POST", + session=custom_session, + ) + assert http_request.get_external_session() is custom_session + + # 2. 没有外部 Session 时,应该尝试使用全局 Session + http_request_no_session = HttpRequest( + url="http://example.com/api", + api_key="test-key", + http_method="POST", + ) + assert http_request_no_session.get_external_session() is None + + +class TestGlobalConnectionPoolAPI(BaseTestEnvironment): + """测试全局连接池 API""" + + def setup_method(self): + """每个测试前重置""" + super().setup_class() + SessionManager.reset_instance() + + def teardown_method(self): + """每个测试后清理""" + dashscope.disable_http_connection_pool() + super().teardown_class() + + def test_enable_http_connection_pool(self): + """测试启用 HTTP 连接池""" + dashscope.enable_http_connection_pool() + + manager = SessionManager.get_instance() + assert manager.is_enabled() + + def test_enable_http_connection_pool_with_params(self): + """测试使用参数启用 HTTP 连接池""" + dashscope.enable_http_connection_pool( + pool_connections=15, + pool_maxsize=30, + max_retries=5, + pool_block=True, + ) + + manager = SessionManager.get_instance() + assert manager.is_enabled() + + config = manager.get_config() + assert config.pool_connections == 15 + assert config.pool_maxsize == 30 + assert config.max_retries == 5 + assert config.pool_block is True + + def test_disable_http_connection_pool(self): + """测试禁用 HTTP 连接池""" + dashscope.enable_http_connection_pool() + assert SessionManager.get_instance().is_enabled() + + dashscope.disable_http_connection_pool() + assert not SessionManager.get_instance().is_enabled() + + def test_reset_http_connection_pool(self): + """测试重置 HTTP 连接池""" + dashscope.enable_http_connection_pool() + # 验证 session 存在 + assert SessionManager.get_instance().get_session() is not None + + # 禁用后重置 + dashscope.disable_http_connection_pool() + dashscope.reset_http_connection_pool() + + # Session 应该被清理 + manager = SessionManager.get_instance() + assert not manager.has_active_session() + assert not manager.is_enabled() + + def test_configure_http_connection_pool(self): + """测试配置 HTTP 连接池""" + dashscope.enable_http_connection_pool() + + dashscope.configure_http_connection_pool( + pool_connections=25, + pool_maxsize=50, + ) + + config = SessionManager.get_instance().get_config() + assert config.pool_connections == 25 + assert config.pool_maxsize == 50 + + def test_configure_before_enable(self): + """测试在启用前配置""" + # 先启用 + dashscope.enable_http_connection_pool() + + # 然后配置 + dashscope.configure_http_connection_pool( + pool_connections=20, + pool_maxsize=40, + ) + + manager = SessionManager.get_instance() + assert manager.is_enabled() + + config = manager.get_config() + assert config.pool_connections == 20 + assert config.pool_maxsize == 40 + + +class TestCustomSessionSupport: + """测试自定义 Session 支持""" + + def test_custom_session_with_headers(self): + """测试自定义 Session 带请求头""" + session = requests.Session() + session.headers.update({"X-Custom-Header": "TestValue"}) + + http_request = HttpRequest( + url="http://example.com/api", + api_key="test-key", + http_method="POST", + session=session, + ) + + assert http_request.get_external_session() is session + assert session.headers.get("X-Custom-Header") == "TestValue" + + def test_custom_session_with_proxies(self): + """测试自定义 Session 带代理""" + session = requests.Session() + session.proxies = {"https": "https://proxy.example.com:8080"} + + http_request = HttpRequest( + url="http://example.com/api", + api_key="test-key", + http_method="POST", + session=session, + ) + + assert http_request.get_external_session() is session + assert session.proxies.get("https") == "https://proxy.example.com:8080" + + def test_custom_session_with_adapter(self): + """测试自定义 Session 带自定义 Adapter""" + session = requests.Session() + adapter = HTTPAdapter( + pool_connections=50, + pool_maxsize=100, + ) + session.mount("https://", adapter) + session.mount("http://", adapter) + + http_request = HttpRequest( + url="http://example.com/api", + api_key="test-key", + http_method="POST", + session=session, + ) + + assert http_request.get_external_session() is session + + # 验证 Adapter 配置 + http_adapter = session.get_adapter("http://") + assert isinstance(http_adapter, HTTPAdapter) + + +class TestThreadSafety: + """测试线程安全性""" + + def setup_method(self): + """每个测试前重置""" + SessionManager.reset_instance() + + def teardown_method(self): + """每个测试后清理""" + manager = SessionManager.get_instance() + manager.reset() + + def test_concurrent_enable_disable(self): + """测试并发启用和禁用""" + manager = SessionManager.get_instance() + + def toggle_enable(): + for _ in range(10): + manager.enable() + time.sleep(0.001) + manager.disable() + time.sleep(0.001) + + threads = [threading.Thread(target=toggle_enable) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # 不应该抛出异常 + assert True + + def test_concurrent_get_session(self): + """测试并发获取 Session""" + manager = SessionManager.get_instance() + manager.enable() + + sessions = [] + + def get_session(): + for _ in range(10): + s = manager.get_session() + sessions.append(s) + time.sleep(0.001) + + threads = [threading.Thread(target=get_session) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + # 所有获取的 Session 应该是同一个 + assert all(s is sessions[0] for s in sessions) + + def test_concurrent_configure(self): + """测试并发配置""" + manager = SessionManager.get_instance() + manager.enable() + + def configure(): + for i in range(5): + manager.configure( + pool_connections=10 + i, + pool_maxsize=20 + i * 2, + ) + time.sleep(0.001) + + threads = [threading.Thread(target=configure) for _ in range(5)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # 不应该抛出异常,最终配置应该是有效的 + config = manager.get_config() + assert config.pool_connections > 0 + assert config.pool_maxsize >= config.pool_connections + + +class TestEdgeCases: + """测试边界情况""" + + def setup_method(self): + """每个测试前重置""" + SessionManager.reset_instance() + + def teardown_method(self): + """每个测试后清理""" + manager = SessionManager.get_instance() + manager.reset() + + def test_enable_multiple_times(self): + """测试多次启用""" + manager = SessionManager.get_instance() + + manager.enable() + session1 = manager.get_session() + + manager.enable() + session2 = manager.get_session() + + # 应该返回同一个 Session + assert session1 is session2 + + def test_configure_with_partial_params(self): + """测试部分参数配置""" + manager = SessionManager.get_instance() + manager.enable() + + # 只配置部分参数 + manager.configure(pool_connections=15) + + config = manager.get_config() + assert config.pool_connections == 15 + assert config.pool_maxsize == 20 # 保持默认值 + assert config.max_retries == 3 # 保持默认值 + + def test_reset_when_disabled(self): + """测试禁用状态下重置""" + manager = SessionManager.get_instance() + manager.disable() + + # 不应该抛出异常 + manager.reset() + assert not manager.is_enabled() + + def test_get_session_after_reset(self): + """测试重置后获取 Session""" + manager = SessionManager.get_instance() + manager.enable() + + old_session = manager.get_session() + + # 禁用后重置 + manager.disable() + manager.reset() + + # 重置后应该返回 None + assert manager.get_session() is None + + # 重新启用后应该是新的 Session + manager.enable() + new_session = manager.get_session() + assert new_session is not None + assert new_session is not old_session + + +class TestBackwardCompatibility: + """测试向后兼容性""" + + def test_http_request_without_session_param(self): + """测试不传 session 参数的 HttpRequest(向后兼容)""" + # 不传 session 参数应该正常工作 + http_request = HttpRequest( + url="http://example.com/api", + api_key="test-key", + http_method="POST", + ) + + assert http_request.get_external_session() is None + + def test_default_behavior_unchanged(self): + """测试默认行为未改变(需要在干净环境中测试)""" + # 重置到初始状态 + manager = SessionManager.get_instance() + manager.disable() + manager.reset() + + # 默认应该是禁用状态 + assert not manager.is_enabled() + + # 默认获取 Session 应该返回 None + assert manager.get_session() is None + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"])