Skip to content

Commit e89583f

Browse files
jameszyaoSimsonW
authored andcommitted
feat: chat completion in stream mode
1 parent 1a71022 commit e89583f

File tree

6 files changed

+299
-61
lines changed

6 files changed

+299
-61
lines changed

taskingai/client/api/inference_api.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
# python 2 and python 3 compatibility library
1414
import six
1515

16-
from taskingai.client.api_client import SyncApiClient
17-
16+
from ..api_client import SyncApiClient
17+
from ..stream import Stream
18+
from ..models import INFERENCE_CHAT_COMPLETION_STREAM_CAST_MAP
1819

1920
class InferenceApi(object):
2021

@@ -23,38 +24,31 @@ def __init__(self, api_client=None):
2324
api_client = SyncApiClient()
2425
self.api_client = api_client
2526

26-
def chat_completion(self, body, **kwargs): # noqa: E501
27+
def chat_completion(self, body, stream = False, **kwargs): # noqa: E501
2728
"""Chat Completion # noqa: E501
2829
2930
Model inference for chat completion. # noqa: E501
30-
This method makes a synchronous HTTP request by default. To make an
31-
asynchronous HTTP request, please pass async_req=True
32-
>>> thread = api.chat_completion(body, async_req=True)
33-
>>> result = thread.get()
34-
35-
:param async_req bool
3631
:param ChatCompletionRequest body: (required)
3732
:return: object
3833
If the method is called asynchronously,
3934
returns the request thread.
4035
"""
4136
kwargs['_return_http_data_only'] = True
42-
if kwargs.get('async_req'):
43-
return self.chat_completion_with_http_info(body, **kwargs) # noqa: E501
37+
cast_map = INFERENCE_CHAT_COMPLETION_STREAM_CAST_MAP
38+
response = self.chat_completion_with_http_info(body, stream, **kwargs)
39+
if not stream:
40+
return response
4441
else:
45-
(data) = self.chat_completion_with_http_info(body, **kwargs) # noqa: E501
46-
return data
42+
return Stream(
43+
cast_map=cast_map,
44+
response=response,
45+
client=self.api_client
46+
)
4747

48-
def chat_completion_with_http_info(self, body, **kwargs): # noqa: E501
48+
def chat_completion_with_http_info(self, body, stream, **kwargs): # noqa: E501
4949
"""Chat Completion # noqa: E501
5050
5151
Model inference for chat completion. # noqa: E501
52-
This method makes a synchronous HTTP request by default. To make an
53-
asynchronous HTTP request, please pass async_req=True
54-
>>> thread = api.chat_completion_with_http_info(body, async_req=True)
55-
>>> result = thread.get()
56-
57-
:param async_req bool
5852
:param ChatCompletionRequest body: (required)
5953
:return: object
6054
If the method is called asynchronously,
@@ -106,7 +100,7 @@ def chat_completion_with_http_info(self, body, **kwargs): # noqa: E501
106100
# Authentication setting
107101
auth_settings = [] # noqa: E501
108102

109-
return self.api_client.call_api(
103+
response = self.api_client.call_api(
110104
'/v1/inference/chat_completion', 'POST',
111105
path_params,
112106
query_params,
@@ -119,7 +113,10 @@ def chat_completion_with_http_info(self, body, **kwargs): # noqa: E501
119113
_return_http_data_only=params.get('_return_http_data_only'),
120114
_preload_content=params.get('_preload_content', True),
121115
_request_timeout=params.get('_request_timeout'),
122-
collection_formats=collection_formats)
116+
collection_formats=collection_formats,
117+
stream=stream
118+
)
119+
return response
123120

124121
def text_embedding(self, body, **kwargs): # noqa: E501
125122
"""Text Embedding # noqa: E501
@@ -137,11 +134,8 @@ def text_embedding(self, body, **kwargs): # noqa: E501
137134
returns the request thread.
138135
"""
139136
kwargs['_return_http_data_only'] = True
140-
if kwargs.get('async_req'):
141-
return self.text_embedding_with_http_info(body, **kwargs) # noqa: E501
142-
else:
143-
(data) = self.text_embedding_with_http_info(body, **kwargs) # noqa: E501
144-
return data
137+
(data) = self.text_embedding_with_http_info(body, **kwargs) # noqa: E501
138+
return data
145139

146140
def text_embedding_with_http_info(self, body, **kwargs): # noqa: E501
147141
"""Text Embedding # noqa: E501

taskingai/client/api_client.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def __call_api(
437437
query_params=None, header_params=None, body=None, post_params=None,
438438
files=None, response_type=None, auth_settings=None,
439439
_return_http_data_only=None, collection_formats=None,
440-
_preload_content=True, _request_timeout=None):
440+
_preload_content=True, _request_timeout=None, stream=False):
441441

442442
config = self.configuration
443443

@@ -490,11 +490,15 @@ def __call_api(
490490

491491
# perform request and return response
492492
response_data = self.request(
493-
method, url, query_params=query_params, headers=header_params,
493+
method, url, stream=stream,
494+
query_params=query_params, headers=header_params,
494495
post_params=post_params, body=body,
495496
_preload_content=_preload_content,
496497
_request_timeout=_request_timeout)
497498

499+
if stream:
500+
return response_data
501+
498502
self.last_response = response_data
499503

500504
return_data = response_data
@@ -517,7 +521,7 @@ def call_api(self, resource_path, method,
517521
body=None, post_params=None, files=None,
518522
response_type=None, auth_settings=None,
519523
_return_http_data_only=None, collection_formats=None,
520-
_preload_content=True, _request_timeout=None):
524+
_preload_content=True, _request_timeout=None, stream=False):
521525
"""Makes the HTTP request (synchronous) and returns deserialized data.
522526
523527
:param resource_path: Path to method endpoint.
@@ -551,27 +555,30 @@ def call_api(self, resource_path, method,
551555
body, post_params, files,
552556
response_type, auth_settings,
553557
_return_http_data_only, collection_formats,
554-
_preload_content, _request_timeout)
558+
_preload_content, _request_timeout, stream)
555559

556560

557-
def request(self, method, url, query_params=None, headers=None,
561+
def request(self, method, url, stream=False, query_params=None, headers=None,
558562
post_params=None, body=None, _preload_content=True,
559563
_request_timeout=None):
560564
"""Makes the HTTP request using RESTClient."""
561565
if method == "GET":
562566
return self.rest_client.GET(url,
567+
stream=stream,
563568
query_params=query_params,
564569
_preload_content=_preload_content,
565570
_request_timeout=_request_timeout,
566571
headers=headers)
567572
elif method == "HEAD":
568573
return self.rest_client.HEAD(url,
574+
stream=stream,
569575
query_params=query_params,
570576
_preload_content=_preload_content,
571577
_request_timeout=_request_timeout,
572578
headers=headers)
573579
elif method == "OPTIONS":
574580
return self.rest_client.OPTIONS(url,
581+
stream=stream,
575582
query_params=query_params,
576583
headers=headers,
577584
post_params=post_params,
@@ -580,6 +587,7 @@ def request(self, method, url, query_params=None, headers=None,
580587
body=body)
581588
elif method == "POST":
582589
return self.rest_client.POST(url,
590+
stream=stream,
583591
query_params=query_params,
584592
headers=headers,
585593
post_params=post_params,
@@ -588,6 +596,7 @@ def request(self, method, url, query_params=None, headers=None,
588596
body=body)
589597
elif method == "PUT":
590598
return self.rest_client.PUT(url,
599+
stream=stream,
591600
query_params=query_params,
592601
headers=headers,
593602
post_params=post_params,
@@ -596,6 +605,7 @@ def request(self, method, url, query_params=None, headers=None,
596605
body=body)
597606
elif method == "PATCH":
598607
return self.rest_client.PATCH(url,
608+
stream=stream,
599609
query_params=query_params,
600610
headers=headers,
601611
post_params=post_params,
@@ -604,6 +614,7 @@ def request(self, method, url, query_params=None, headers=None,
604614
body=body)
605615
elif method == "DELETE":
606616
return self.rest_client.DELETE(url,
617+
stream=stream,
607618
query_params=query_params,
608619
headers=headers,
609620
_preload_content=_preload_content,

taskingai/client/models/entity/inference/chat_completion.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
"ChatCompletionFunctionCall",
1717
"ChatCompletionFunction",
1818
"ChatCompletionFinishReason",
19+
"ChatCompletionChunk",
20+
"INFERENCE_CHAT_COMPLETION_STREAM_CAST_MAP"
1921
]
2022

2123
class ChatCompletionRole(str, Enum):
@@ -71,3 +73,17 @@ class ChatCompletion(TaskingaiBaseModel):
7173
finish_reason: ChatCompletionFinishReason
7274
message: ChatCompletionAssistantMessage
7375
created_timestamp: int
76+
77+
78+
class ChatCompletionChunk(TaskingaiBaseModel):
79+
object: str
80+
role: ChatCompletionRole
81+
index: int
82+
delta: str
83+
created_timestamp: int
84+
85+
86+
INFERENCE_CHAT_COMPLETION_STREAM_CAST_MAP = {
87+
"ChatCompletion": ChatCompletion,
88+
"ChatCompletionChunk": ChatCompletionChunk
89+
}

taskingai/client/rest.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -41,40 +41,39 @@ def getheader(self, name, default=None):
4141
class RESTSyncClientObject(object):
4242

4343
def __init__(self, configuration, pools_size=4, maxsize=None):
44-
# 设置连接池的最大并发连接数
44+
# set default user agent
4545
if maxsize is None:
4646
maxsize = configuration.connection_pool_maxsize if configuration.connection_pool_maxsize is not None else 4
4747

48-
# 设置连接限制
4948
limits = httpx.Limits(max_connections=maxsize, max_keepalive_connections=pools_size)
5049

51-
# 设置 SSL 配置
50+
# SSL configuration
5251
verify = configuration.ssl_ca_cert or True # 如果提供了自定义 CA 证书则使用,否则默认启用 SSL 验证
5352
if not configuration.verify_ssl:
5453
verify = False # 如果明确指定不进行 SSL 验证,则设置为 False
5554

56-
# 设置代理
55+
# proxy configuration
5756
proxies = None
5857
if configuration.proxy:
5958
proxies = {
6059
'http://': configuration.proxy,
6160
'https://': configuration.proxy,
6261
}
6362

64-
# 创建 httpx 客户端
63+
# create httpx client
6564
self.client = httpx.Client(
6665
limits=limits,
6766
verify=verify,
6867
proxies=proxies,
6968
)
7069

71-
# 如果有提供客户端证书,设置之
70+
# set client cert if provided
7271
if configuration.cert_file and configuration.key_file:
7372
self.client.cert = (configuration.cert_file, configuration.key_file)
7473

75-
def request(self, method, url, query_params=None, headers=None,
74+
def request(self, method, url, stream = False, query_params=None, headers=None,
7675
body=None, post_params=None, _preload_content=True,
77-
_request_timeout=None):
76+
_request_timeout=None) -> RESTResponse | httpx.Response:
7877
"""
7978
Perform asynchronous HTTP requests.
8079
@@ -110,13 +109,23 @@ def request(self, method, url, query_params=None, headers=None,
110109
request_body = json.dumps(body) if body is not None else None
111110

112111
try:
113-
r = self.client.request(
114-
method, url,
115-
params=query_params,
116-
headers=headers,
117-
content=request_body,
118-
timeout=_request_timeout
119-
)
112+
if stream:
113+
with self.client.stream(
114+
method, url,
115+
params=query_params,
116+
headers=headers,
117+
content=request_body,
118+
timeout=_request_timeout
119+
) as r:
120+
return r
121+
else:
122+
r = self.client.request(
123+
method, url,
124+
params=query_params,
125+
headers=headers,
126+
content=request_body,
127+
timeout=_request_timeout
128+
)
120129
except HTTPError as e:
121130
msg = "{0}\n{1}".format(type(e).__name__, str(e))
122131
raise ApiException(status=0, reason=msg)
@@ -129,72 +138,78 @@ def request(self, method, url, query_params=None, headers=None,
129138

130139
return r
131140

132-
def GET(self, url, headers=None, query_params=None, _preload_content=True,
141+
def GET(self, url, stream=False, headers=None, query_params=None, _preload_content=True,
133142
_request_timeout=None):
134143
return self.request("GET", url,
144+
stream=stream,
135145
headers=headers,
136146
_preload_content=_preload_content,
137147
_request_timeout=_request_timeout,
138148
query_params=query_params)
139149

140-
def HEAD(self, url, headers=None, query_params=None, _preload_content=True,
150+
def HEAD(self, url, stream=False, headers=None, query_params=None, _preload_content=True,
141151
_request_timeout=None):
142152
return self.request("HEAD", url,
153+
stream=stream,
143154
headers=headers,
144155
_preload_content=_preload_content,
145156
_request_timeout=_request_timeout,
146157
query_params=query_params)
147158

148-
def OPTIONS(self, url, headers=None, query_params=None, post_params=None,
159+
def OPTIONS(self, url, stream=False, headers=None, query_params=None, post_params=None,
149160
body=None, _preload_content=True, _request_timeout=None):
150161
return self.request("OPTIONS", url,
162+
stream=stream,
151163
headers=headers,
152164
query_params=query_params,
153165
post_params=post_params,
154166
_preload_content=_preload_content,
155167
_request_timeout=_request_timeout,
156168
body=body)
157169

158-
def DELETE(self, url, headers=None, query_params=None, body=None,
170+
def DELETE(self, url, stream=False, headers=None, query_params=None, body=None,
159171
_preload_content=True, _request_timeout=None):
160172
return self.request("DELETE", url,
173+
stream=stream,
161174
headers=headers,
162175
query_params=query_params,
163176
_preload_content=_preload_content,
164177
_request_timeout=_request_timeout,
165178
body=body)
166179

167-
def POST(self, url, headers=None, query_params=None, post_params=None,
180+
def POST(self, url, stream=False, headers=None, query_params=None, post_params=None,
168181
body=None, _preload_content=True, _request_timeout=None):
169182
return self.request("POST", url,
183+
stream=stream,
170184
headers=headers,
171185
query_params=query_params,
172186
post_params=post_params,
173187
_preload_content=_preload_content,
174188
_request_timeout=_request_timeout,
175189
body=body)
176190

177-
def PUT(self, url, headers=None, query_params=None, post_params=None,
191+
def PUT(self, url, stream=False, headers=None, query_params=None, post_params=None,
178192
body=None, _preload_content=True, _request_timeout=None):
179193
return self.request("PUT", url,
180194
headers=headers,
195+
stream=stream,
181196
query_params=query_params,
182197
post_params=post_params,
183198
_preload_content=_preload_content,
184199
_request_timeout=_request_timeout,
185200
body=body)
186201

187-
def PATCH(self, url, headers=None, query_params=None, post_params=None,
202+
def PATCH(self, url, stream=False, headers=None, query_params=None, post_params=None,
188203
body=None, _preload_content=True, _request_timeout=None):
189204
return self.request("PATCH", url,
205+
stream=stream,
190206
headers=headers,
191207
query_params=query_params,
192208
post_params=post_params,
193209
_preload_content=_preload_content,
194210
_request_timeout=_request_timeout,
195211
body=body)
196212

197-
198213
class RESTAsyncClientObject(object):
199214

200215
def __init__(self, configuration, pools_size=4, maxsize=None):

0 commit comments

Comments
 (0)