X Close
Skip to content

Commit af29c82

Browse files
committed
unify SSE filter, use simplified parsing without lpeg, fixes #7
1 parent 249a195 commit af29c82

File tree

11 files changed

+293
-175
lines changed

11 files changed

+293
-175
lines changed

lua-openai-dev-1.rockspec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ that supports the LuaSocket request interface. Compatible with OpenResty using
1414
}
1515
dependencies = {
1616
"lua >= 5.1",
17-
"lpeg",
1817
"lua-cjson",
1918
"tableshape",
2019
"luasocket",
@@ -24,6 +23,7 @@ build = {
2423
type = "builtin",
2524
modules = {
2625
["openai"] = "openai/init.lua",
26+
["openai.sse"] = "openai/sse.lua",
2727
["openai.chat_completions"] = "openai/chat_completions.lua",
2828
["openai.responses"] = "openai/responses.lua",
2929
["openai.compat.gemini"] = "openai/compat/gemini.lua"

openai/chat_completions.lua

Lines changed: 3 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -84,49 +84,8 @@ local parse_completion_chunk = types.partial({
8484
}) % function(value, state)
8585
return setmetatable(state, completion_chunk_mt)
8686
end
87-
local consume_json_head
88-
do
89-
local C, S, P
90-
do
91-
local _obj_0 = require("lpeg")
92-
C, S, P = _obj_0.C, _obj_0.S, _obj_0.P
93-
end
94-
local consume_json = P(function(str, pos)
95-
local str_len = #str
96-
for k = pos + 1, str_len do
97-
local candidate = str:sub(pos, k)
98-
local parsed = false
99-
pcall(function()
100-
parsed = cjson.decode(candidate)
101-
end)
102-
if parsed then
103-
return k + 1
104-
end
105-
end
106-
return nil
107-
end)
108-
consume_json_head = S("\t\n\r ") ^ 0 * P("data: ") * C(consume_json) * C(P(1) ^ 0)
109-
end
110-
local create_chat_stream_filter
111-
create_chat_stream_filter = function(chunk_callback)
112-
assert(types["function"](chunk_callback), "Must provide chunk_callback function when streaming response")
113-
local accumulation_buffer = ""
114-
return function(...)
115-
local chunk = ...
116-
if type(chunk) == "string" then
117-
accumulation_buffer = accumulation_buffer .. chunk
118-
while true do
119-
local json_blob, rest = consume_json_head:match(accumulation_buffer)
120-
if not (json_blob) then
121-
break
122-
end
123-
accumulation_buffer = rest
124-
chunk_callback(cjson.decode(json_blob))
125-
end
126-
end
127-
return ...
128-
end
129-
end
87+
local create_stream_filter
88+
create_stream_filter = require("openai.sse").create_stream_filter
13089
local ChatSession
13190
do
13291
local _class_0
@@ -187,7 +146,7 @@ do
187146
if stream_callback then
188147
assert(type(response) == "string", "Expected string response from streaming output")
189148
local parts = { }
190-
local f = create_chat_stream_filter(function(c)
149+
local f = create_stream_filter(function(c)
191150
do
192151
local parsed = parse_completion_chunk(c)
193152
if parsed then
@@ -260,6 +219,5 @@ end
260219
return {
261220
ChatSession = ChatSession,
262221
test_message = test_message,
263-
create_chat_stream_filter = create_chat_stream_filter,
264222
parse_completion_chunk = parse_completion_chunk
265223
}

openai/chat_completions.moon

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -105,52 +105,7 @@ parse_completion_chunk = types.partial({
105105
}
106106
}) % (value, state) -> setmetatable state, completion_chunk_mt
107107

108-
-- lpeg pattern to read a json data block from the front of a string, returns
109-
-- the json blob and the rest of the string if it could parse one
110-
consume_json_head = do
111-
import C, S, P from require "lpeg"
112-
113-
-- this pattern reads from the front just enough characters to consume a
114-
-- valid json object
115-
consume_json = P (str, pos) ->
116-
str_len = #str
117-
for k=pos+1,str_len
118-
candidate = str\sub pos, k
119-
parsed = false
120-
pcall -> parsed = cjson.decode candidate
121-
if parsed
122-
return k + 1
123-
124-
return nil -- fail
125-
126-
S("\t\n\r ")^0 * P("data: ") * C(consume_json) * C(P(1)^0)
127-
128-
129-
-- creates a ltn12 compatible filter function that will call chunk_callback
130-
-- for each parsed json chunk from the server-sent events api response
131-
create_chat_stream_filter = (chunk_callback) ->
132-
assert types.function(chunk_callback), "Must provide chunk_callback function when streaming response"
133-
134-
accumulation_buffer = ""
135-
136-
(...) ->
137-
chunk = ...
138-
139-
if type(chunk) == "string"
140-
accumulation_buffer ..= chunk
141-
142-
while true
143-
json_blob, rest = consume_json_head\match accumulation_buffer
144-
unless json_blob
145-
break
146-
147-
accumulation_buffer = rest
148-
chunk_callback cjson.decode json_blob
149-
-- if chunk = parse_completion_chunk cjson.decode json_blob
150-
-- chunk_callback chunk
151-
152-
...
153-
108+
import create_stream_filter from require "openai.sse"
154109

155110
-- handles appending response for each call to chat
156111
-- TODO: hadle appending the streaming response to the output
@@ -219,7 +174,7 @@ class ChatSession
219174
"Expected string response from streaming output"
220175

221176
parts = {}
222-
f = create_chat_stream_filter (c) ->
177+
f = create_stream_filter (c) ->
223178
if parsed = parse_completion_chunk c
224179
table.insert parts, parsed.content
225180

@@ -256,6 +211,5 @@ class ChatSession
256211
{
257212
:ChatSession
258213
:test_message
259-
:create_chat_stream_filter
260214
:parse_completion_chunk
261215
}

openai/init.lua

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,10 @@ do
2828
if chunk_callback == nil then
2929
chunk_callback = nil
3030
end
31-
local test_message, create_chat_stream_filter
32-
do
33-
local _obj_0 = require("openai.chat_completions")
34-
test_message, create_chat_stream_filter = _obj_0.test_message, _obj_0.create_chat_stream_filter
35-
end
31+
local test_message
32+
test_message = require("openai.chat_completions").test_message
33+
local create_stream_filter
34+
create_stream_filter = require("openai.sse").create_stream_filter
3635
local test_messages = types.array_of(test_message)
3736
assert(test_messages(messages))
3837
local payload = {
@@ -46,7 +45,7 @@ do
4645
end
4746
local stream_filter
4847
if payload.stream then
49-
stream_filter = create_chat_stream_filter(chunk_callback)
48+
stream_filter = create_stream_filter(chunk_callback)
5049
end
5150
return self:_request("POST", "/chat/completions", payload, nil, stream_filter)
5251
end,
@@ -164,8 +163,8 @@ do
164163
if stream_callback == nil then
165164
stream_callback = nil
166165
end
167-
local create_response_stream_filter
168-
create_response_stream_filter = require("openai.responses").create_response_stream_filter
166+
local create_stream_filter
167+
create_stream_filter = require("openai.sse").create_stream_filter
169168
local payload = {
170169
model = self.default_model,
171170
input = input
@@ -177,7 +176,7 @@ do
177176
end
178177
local stream_filter
179178
if payload.stream and stream_callback then
180-
stream_filter = create_response_stream_filter(stream_callback)
179+
stream_filter = create_stream_filter(stream_callback)
181180
end
182181
return self:_request("POST", "/responses", payload, nil, stream_filter)
183182
end,

openai/init.moon

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ class OpenAI
3838
-- opts: additional parameters as described in https://platform.openai.com/docs/api-reference/chat, eg. model, temperature, etc.
3939
-- completion_callback: function to be called for parsed streaming output when stream = true is passed to opts
4040
create_chat_completion: (messages, opts, chunk_callback=nil) =>
41-
import test_message, create_chat_stream_filter from require "openai.chat_completions"
41+
import test_message from require "openai.chat_completions"
42+
import create_stream_filter from require "openai.sse"
4243

4344
test_messages = types.array_of test_message
4445
assert test_messages messages
@@ -53,7 +54,7 @@ class OpenAI
5354
payload[k] = v
5455

5556
stream_filter = if payload.stream
56-
create_chat_stream_filter chunk_callback
57+
create_stream_filter chunk_callback
5758

5859
@_request "POST", "/chat/completions", payload, nil, stream_filter
5960

@@ -176,7 +177,7 @@ class OpenAI
176177
-- stream_callback: optional function for streaming responses
177178
-- Returns: status, response, headers (raw result from _request)
178179
create_response: (input, opts={}, stream_callback=nil) =>
179-
import create_response_stream_filter from require "openai.responses"
180+
import create_stream_filter from require "openai.sse"
180181

181182
payload = {
182183
model: @default_model
@@ -188,7 +189,7 @@ class OpenAI
188189
payload[k] = v
189190

190191
stream_filter = if payload.stream and stream_callback
191-
create_response_stream_filter stream_callback
192+
create_stream_filter stream_callback
192193

193194
@_request "POST", "/responses", payload, nil, stream_filter
194195

openai/responses.lua

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -129,38 +129,6 @@ local parse_responses_response = types.partial({
129129
usage = empty + types.table:tag("usage"),
130130
status = empty + types.string:tag("status")
131131
})
132-
local create_response_stream_filter
133-
create_response_stream_filter = function(chunk_callback)
134-
assert(types["function"](chunk_callback), "Must provide chunk_callback function when streaming response")
135-
local buffer = ""
136-
return function(...)
137-
local chunk = ...
138-
if type(chunk) == "string" then
139-
buffer = buffer .. chunk
140-
while true do
141-
local newline_pos = buffer:find("\n")
142-
if not (newline_pos) then
143-
break
144-
end
145-
local line = buffer:sub(1, newline_pos - 1)
146-
buffer = buffer:sub(newline_pos + 1)
147-
line = line:gsub("%s*$", "")
148-
if line:match("^data: ") then
149-
local json_data = line:sub(7)
150-
if json_data ~= "[DONE]" then
151-
local success, parsed = pcall(function()
152-
return cjson.decode(json_data)
153-
end)
154-
if success then
155-
chunk_callback(parsed)
156-
end
157-
end
158-
end
159-
end
160-
end
161-
return ...
162-
end
163-
end
164132
local ResponsesChatSession
165133
do
166134
local _class_0
@@ -270,6 +238,5 @@ do
270238
ResponsesChatSession = _class_0
271239
end
272240
return {
273-
ResponsesChatSession = ResponsesChatSession,
274-
create_response_stream_filter = create_response_stream_filter
241+
ResponsesChatSession = ResponsesChatSession
275242
}

openai/responses.moon

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -100,39 +100,6 @@ parse_responses_response = types.partial {
100100
status: empty + types.string\tag "status"
101101
}
102102

103-
-- creates a ltn12 compatible filter function that will call chunk_callback
104-
-- for each parsed json chunk from the server-sent events api response
105-
create_response_stream_filter = (chunk_callback) ->
106-
assert types.function(chunk_callback), "Must provide chunk_callback function when streaming response"
107-
108-
buffer = ""
109-
110-
(...) ->
111-
chunk = ...
112-
113-
if type(chunk) == "string"
114-
buffer ..= chunk
115-
116-
while true
117-
newline_pos = buffer\find "\n"
118-
break unless newline_pos
119-
120-
line = buffer\sub 1, newline_pos - 1
121-
buffer = buffer\sub newline_pos + 1
122-
123-
line = line\gsub "%s*$", "" -- trim trailing whitespace
124-
125-
if line\match "^data: "
126-
json_data = line\sub 7 -- Remove "data: " prefix
127-
128-
if json_data != "[DONE]"
129-
success, parsed = pcall -> cjson.decode json_data
130-
if success
131-
chunk_callback parsed
132-
133-
...
134-
135-
136103
-- A client side chat session backed by the responses API
137104
class ResponsesChatSession
138105
new: (@client, @opts={}) =>
@@ -216,5 +183,4 @@ class ResponsesChatSession
216183

217184
{
218185
:ResponsesChatSession
219-
:create_response_stream_filter
220186
}

openai/sse.lua

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
local cjson = require("cjson")
2+
local types
3+
types = require("tableshape").types
4+
local create_stream_filter
5+
create_stream_filter = function(chunk_callback)
6+
assert(types["function"](chunk_callback), "Must provide chunk_callback function when streaming response")
7+
local buffer = ""
8+
return function(...)
9+
local chunk = ...
10+
if type(chunk) == "string" then
11+
buffer = buffer .. chunk
12+
while true do
13+
local newline_pos = buffer:find("\n")
14+
if not (newline_pos) then
15+
break
16+
end
17+
local line = buffer:sub(1, newline_pos - 1)
18+
buffer = buffer:sub(newline_pos + 1)
19+
line = line:gsub("^%s+", "")
20+
line = line:gsub("%s*$", "")
21+
if line:match("^data: ") then
22+
local json_data = line:sub(7)
23+
if json_data ~= "[DONE]" then
24+
local success, parsed = pcall(function()
25+
return cjson.decode(json_data)
26+
end)
27+
if success then
28+
chunk_callback(parsed)
29+
end
30+
end
31+
end
32+
end
33+
end
34+
return ...
35+
end
36+
end
37+
return {
38+
create_stream_filter = create_stream_filter
39+
}

0 commit comments

Comments
 (0)
X Close