Optimize the batching logic of the embeddings #27


Open
wants to merge 1 commit into base: pypi/0.0.0-alpha
39 changes: 37 additions & 2 deletions backend/llm_model/embeddings.py
@@ -8,6 +8,19 @@
from config.db_config import db, db_session_manager
from config.logging_config import logger

# OpenAI / Azure OpenAI allows up to 300,000 tokens per embedding request.
# Leave some headroom to reduce the chance of repeated retries at the limit.
_MAX_TOKENS_PER_REQ = 290_000


def _estimate_tokens(text: str) -> int:
"""
Roughly estimate how many tokens `text` consumes.
    For English text, 1 token ≈ 4 characters on average; for Chinese text, 1 token ≈ 1.3–2 characters.
    We use a compromise of 3.5 characters per token so the estimate errs on the high (safe) side.
"""
return max(1, int(len(text) / 3.5))
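    # Example: a chunk trimmed to 8192 characters is estimated at int(8192 / 3.5) = 2340
    # tokens, so roughly 120+ such chunks still fit within a single 290,000-token request.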


class EmbeddingManager:
"""Embedding Manager"""
@@ -144,9 +157,31 @@ async def _get_embeddings_with_context(text: Union[str, List[str]], model_name:
if isinstance(text, str):
embedding = await embedding_model.aembed_query(text[:8192])
else:
text = [t[:8192] for t in text]
embedding = await embedding_model.aembed_documents(text)
# First, trim each text to 8192 characters
texts = [t[:8192] for t in text]

# —— Batching logic —— #
batches, cur_batch, cur_tokens = [], [], 0
for t in texts:
tok = _estimate_tokens(t)
# If adding `t` would exceed the per-request token limit, finalize the current batch
if cur_batch and cur_tokens + tok > _MAX_TOKENS_PER_REQ:
batches.append(cur_batch)
cur_batch, cur_tokens = [], 0
cur_batch.append(t)
cur_tokens += tok
if cur_batch: # Process the last batch
batches.append(cur_batch)

# Send requests sequentially to preserve output order
embedding = []
for bt in batches:
bt_emb = await embedding_model.aembed_documents(bt)
embedding.extend(bt_emb)

return np.array(embedding)

except Exception as e:
logger.error(f"Failed to generate Embedding: {str(e)}")
raise
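
As a sanity check for reviewers, below is a minimal standalone sketch of the greedy batching strategy added in this diff. It copies the 3.5-characters-per-token heuristic and the batching loop, but swaps in a hypothetical tiny token limit and sample texts so the resulting split is easy to inspect; it does not call any embedding model and is not part of the diff itself.

from typing import List


def estimate_tokens(text: str) -> int:
    # Same heuristic as the diff: ~3.5 characters per token, at least 1.
    return max(1, int(len(text) / 3.5))


def split_into_batches(texts: List[str], max_tokens: int) -> List[List[str]]:
    """Greedy batching: start a new batch once the next text would exceed max_tokens."""
    batches, cur_batch, cur_tokens = [], [], 0
    for t in texts:
        tok = estimate_tokens(t)
        if cur_batch and cur_tokens + tok > max_tokens:
            batches.append(cur_batch)
            cur_batch, cur_tokens = [], 0
        cur_batch.append(t)
        cur_tokens += tok
    if cur_batch:
        batches.append(cur_batch)
    return batches


if __name__ == "__main__":
    # Hypothetical limit of 10 tokens so the split is visible with short strings.
    texts = ["a" * 14, "b" * 14, "c" * 14, "d" * 14]  # each estimates to ~4 tokens
    print(split_into_batches(texts, max_tokens=10))
    # -> [['aaaaaaaaaaaaaa', 'bbbbbbbbbbbbbb'], ['cccccccccccccc', 'dddddddddddddd']]

The diff sends the batches sequentially, which keeps the returned embeddings aligned with the input order; they could also be dispatched concurrently with asyncio.gather (which preserves argument order), at the cost of more simultaneous requests counting against the rate limit.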