Implement log probabilities for chat completions #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 3 commits into base: main
38 changes: 38 additions & 0 deletions api/options_test.go
@@ -0,0 +1,38 @@
package api

import "testing"

func TestValidateLogProbs(t *testing.T) {
cases := []struct {
name string
opts Options
wantErr bool
wantTop int
}{
{"disabled", Options{}, false, 0},
{"enabled default", Options{LogProbsEnabled: true}, false, 1},
{"enabled 3", Options{LogProbsEnabled: true, TopLogProbs: 3}, false, 3},
{"enabled max", Options{LogProbsEnabled: true, TopLogProbs: 5}, false, 5},
{"enabled too high", Options{LogProbsEnabled: true, TopLogProbs: 6}, true, 6},
{"enabled negative", Options{LogProbsEnabled: true, TopLogProbs: -1}, true, -1},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := tc.opts.ValidateLogProbs()
if tc.wantErr {
if err == nil {
t.Fatalf("expected error, got nil")
}
} else if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if !tc.wantErr && tc.opts.LogProbsEnabled {
if tc.opts.TopLogProbs != tc.wantTop {
t.Fatalf("expected TopLogProbs %d, got %d", tc.wantTop, tc.opts.TopLogProbs)
}
}
})
}
}
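As a quick check, the table-driven test above exercises both the defaulting behaviour (TopLogProbs 0 becomes 1 when logprobs are enabled) and the 0–5 bound; it can be run on its own with `go test ./api -run TestValidateLogProbs`.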
45 changes: 45 additions & 0 deletions api/types.go
@@ -192,7 +192,11 @@ type ChatResponse struct {

Done bool `json:"done"`

// Metrics contains timing and token usage details
Metrics

// LogProbs optionally contains token-level log probability information
LogProbs *LogProbs `json:"logprobs,omitempty"`
}

type Metrics struct {
@@ -204,6 +208,21 @@ type Metrics struct {
EvalDuration time.Duration `json:"eval_duration,omitempty"`
}

// LogProbs represents token-level log probability data returned by the model
// following the schema used by the OpenAI API.
type LogProbs struct {
// TokenLogprobs is the log probability of each returned token in `Tokens`.
TokenLogprobs []float64 `json:"token_logprobs"`
// Tokens are the textual tokens that were generated.
Tokens []string `json:"tokens"`
// TokenIDs are the token IDs corresponding to `Tokens`.
TokenIDs []int `json:"token_ids"`
// TopLogprobs, if requested, gives the top-n log probabilities for each
// generated token. Each entry is a map from the candidate token string to
// its log probability.
TopLogprobs []map[string]float64 `json:"top_logprobs,omitempty"`
}

// Options specified in [GenerateRequest]. If you add a new option here, also
// add it to the API docs.
type Options struct {
@@ -226,6 +245,29 @@ type Options struct {
MirostatTau float32 `json:"mirostat_tau,omitempty"`
MirostatEta float32 `json:"mirostat_eta,omitempty"`
Stop []string `json:"stop,omitempty"`

// Experimental: enable return of token log probabilities
LogProbsEnabled bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
}

// ValidateLogProbs verifies that any logprob-related settings are in a valid
// range. TopLogProbs is capped at 5, matching the limit of OpenAI's legacy completions API.
// If LogProbsEnabled is true and TopLogProbs is zero, the method sets it to 1
// (the OpenAI default) so that callers don't have to specify both fields.
func (o *Options) ValidateLogProbs() error {
if !o.LogProbsEnabled {
return nil
}

if o.TopLogProbs == 0 {
o.TopLogProbs = 1
}

if o.TopLogProbs < 0 || o.TopLogProbs > 5 {
return fmt.Errorf("top_logprobs must be between 0 and 5")
}
return nil
}

// Runner options which must be set when the model is loaded into memory
@@ -453,6 +495,9 @@ type GenerateResponse struct {
Context []int `json:"context,omitempty"`

Metrics

// LogProbs optionally contains token-level log probability data.
LogProbs *LogProbs `json:"logprobs,omitempty"`
}

// ModelDetails provides details about a model.
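For illustration, a minimal sketch of how a caller might use the new options and response fields once this change lands. The `LogProbsEnabled`, `TopLogProbs`, and `LogProbs` names come from the diff above; the import path and the way `resp` is obtained are assumptions, and in a real program the response would come back from a chat or generate call.

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/api" // assumed module path
)

func main() {
	// Ask for token logprobs with the top 3 candidates per position.
	opts := api.Options{LogProbsEnabled: true, TopLogProbs: 3}
	if err := opts.ValidateLogProbs(); err != nil {
		panic(err) // e.g. top_logprobs outside the 0-5 range
	}

	// Stand-in response; a real one would be returned by a chat request.
	var resp api.ChatResponse
	if resp.LogProbs != nil {
		for i, tok := range resp.LogProbs.Tokens {
			fmt.Printf("%q (id %d) logprob=%.4f\n",
				tok, resp.LogProbs.TokenIDs[i], resp.LogProbs.TokenLogprobs[i])
		}
	}
}
```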
3 changes: 2 additions & 1 deletion llama/runner/runner.go
@@ -603,7 +603,8 @@ type CompletionResponse struct {

func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req CompletionRequest
- req.Options = Options(api.DefaultOptions())
+ // initialize default inference options (use zero value; defaults applied downstream)
+ req.Options = Options{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
30 changes: 30 additions & 0 deletions llm/server.go
@@ -648,6 +648,14 @@ type completion struct {
Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`

// Logprobs holds token log probability information returned by the runner (if requested)
Logprobs struct {
Tokens []string `json:"tokens"`
TokenIDs []int `json:"token_ids"`
TokenLogprobs []float64 `json:"token_logprobs"`
TopLogprobs []map[string]float64 `json:"top_logprobs,omitempty"`
} `json:"logprobs,omitempty"`

Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
@@ -671,6 +679,9 @@ type CompletionResponse struct {
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration

// LogProbs optionally carries token-level log probability data for this response chunk
LogProbs *api.LogProbs
}

func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -698,6 +709,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
"cache_prompt": true,
}

// forward the logprobs options to the runner when requested
if req.Options.LogProbsEnabled {
request["logprobs"] = true
}
if req.Options.TopLogProbs > 0 {
request["top_logprobs"] = req.Options.TopLogProbs
}

if len(req.Format) > 0 {
switch string(req.Format) {
case `null`, `""`:
@@ -818,8 +837,18 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}

if c.Content != "" {
var lp *api.LogProbs
if len(c.Logprobs.Tokens) > 0 || len(c.Logprobs.TokenIDs) > 0 {
lp = &api.LogProbs{
Tokens: c.Logprobs.Tokens,
TokenIDs: c.Logprobs.TokenIDs,
TokenLogprobs: c.Logprobs.TokenLogprobs,
TopLogprobs: c.Logprobs.TopLogprobs,
}
}
fn(CompletionResponse{
Content: c.Content,
LogProbs: lp,
})
}

@@ -836,6 +865,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN,
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
LogProbs: nil,
})
return nil
}
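To make the wire format concrete, here is a self-contained sketch (not part of the PR) that decodes a runner chunk of the shape the code above expects. The JSON keys mirror the struct tags added to `completion`; the local `chunk` type and the sample values are invented for illustration. On the request side, the only change is that the map sent to the runner gains `"logprobs": true` and, when set, a `"top_logprobs"` count.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// chunk mirrors the logprobs portion of the runner's streamed completion
// payload as declared in llm/server.go above.
type chunk struct {
	Content  string `json:"content"`
	Logprobs struct {
		Tokens        []string             `json:"tokens"`
		TokenIDs      []int                `json:"token_ids"`
		TokenLogprobs []float64            `json:"token_logprobs"`
		TopLogprobs   []map[string]float64 `json:"top_logprobs,omitempty"`
	} `json:"logprobs,omitempty"`
}

func main() {
	// Invented sample chunk; real values come from the llama runner.
	data := []byte(`{
		"content": "Hello",
		"logprobs": {
			"tokens": ["Hello"],
			"token_ids": [15043],
			"token_logprobs": [-0.12],
			"top_logprobs": [{"Hello": -0.12, "Hi": -2.31}]
		}
	}`)

	var c chunk
	if err := json.Unmarshal(data, &c); err != nil {
		panic(err)
	}
	fmt.Println(c.Logprobs.Tokens, c.Logprobs.TokenLogprobs)
}
```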