feat(api): Add transcribe response format request parameter & adjust STT backends (#8318)
* WIP response format implementation for audio transcriptions (cherry picked from commit e271dd764bbc13846accf3beb8b6522153aa276f)
* Rework transcript response_format and add more formats (cherry picked from commit 6a93a8f63e2ee5726bca2980b0c9cf4ef8b7aeb8)
* Add test and replace go-openai package with official openai go client (cherry picked from commit f25d1a04e46526429c89db4c739e1e65942ca893)
* Fix faster-whisper backend and refactor transcription formatting to also work on CLI (cherry picked from commit 69a93977d5e113eb7172bd85a0f918592d3d2168)

Signed-off-by: Andres Smith <andressmithdev@pm.me>
Co-authored-by: nanoandrew4 <nanoandrew4@gmail.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
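For context, the new response_format form field can be exercised end-to-end with the official OpenAI Go client, much as the e2e test added by this PR does. The following is a minimal sketch, not part of the PR; the base URL, API key, and audio file name are placeholders for a local deployment:

	package main

	import (
		"context"
		"fmt"
		"os"

		"github.com/openai/openai-go/v3"
		"github.com/openai/openai-go/v3/option"
	)

	func main() {
		// Placeholders: point these at your own LocalAI instance and audio file.
		client := openai.NewClient(
			option.WithBaseURL("http://localhost:8080/v1"),
			option.WithAPIKey("none"),
		)

		f, err := os.Open("audio.wav")
		if err != nil {
			panic(err)
		}
		defer f.Close()

		// Non-JSON formats come back as plain text, so capture the raw response body.
		var vtt string
		_, err = client.Audio.Transcriptions.New(context.Background(),
			openai.AudioTranscriptionNewParams{
				Model:          openai.AudioModelWhisper1,
				File:           f,
				ResponseFormat: openai.AudioResponseFormatVTT,
			}, option.WithResponseBodyInto(&vtt))
		if err != nil {
			panic(err)
		}
		fmt.Println(vtt) // starts with "WEBVTT"
	}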
.gitignore (vendored) | 2

@@ -36,6 +36,8 @@ LocalAI
 models/*
 test-models/
 test-dir/
+tests/e2e-aio/backends
+tests/e2e-aio/models
 
 release/
 
@@ -130,8 +130,9 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
 	segments := []*pb.TranscriptSegment{}
 	text := ""
 	for i := range int(segsLen) {
-		s := CppGetSegmentStart(i)
-		t := CppGetSegmentEnd(i)
+		// segment start/end conversion factor taken from https://github.com/ggml-org/whisper.cpp/blob/master/examples/cli/cli.cpp#L895
+		s := CppGetSegmentStart(i) * (10000000)
+		t := CppGetSegmentEnd(i) * (10000000)
 		txt := strings.Clone(CppGetSegmentText(i))
 		tokens := make([]int32, CppNTokens(i))
 
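The 10,000,000 factor follows from the units involved: whisper.cpp reports segment boundaries in ticks of 10 ms (the conversion the linked cli.cpp performs), while the segment fields are consumed downstream as Go time.Duration values, which count nanoseconds. A small illustrative sketch (the names here are ours, not the backend's):

	package main

	import (
		"fmt"
		"time"
	)

	// whisper.cpp segment timestamps are counted in 10 ms ticks;
	// time.Duration counts nanoseconds, so one tick is 10,000,000 ns.
	const tickNanos = 10_000_000

	func tickToDuration(tick int64) time.Duration {
		return time.Duration(tick * tickNanos)
	}

	func main() {
		fmt.Println(tickToDuration(150)) // prints 1.5s
	}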
@@ -40,7 +40,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             device = "mps"
         try:
             print("Preparing models, please wait", file=sys.stderr)
-            self.model = WhisperModel(request.Model, device=device, compute_type="float16")
+            self.model = WhisperModel(request.Model, device=device, compute_type="default")
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
         # Implement your logic here for the LoadModel service
@@ -55,11 +55,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             id = 0
             for segment in segments:
                 print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
-                resultSegments.append(backend_pb2.TranscriptSegment(id=id, start=segment.start, end=segment.end, text=segment.text))
+                resultSegments.append(backend_pb2.TranscriptSegment(id=id, start=int(segment.start)*1e9, end=int(segment.end)*1e9, text=segment.text))
                 text += segment.text
                 id += 1
         except Exception as err:
             print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
+            raise err
 
         return backend_pb2.TranscriptResult(segments=resultSegments, text=text)
 
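The faster-whisper change applies the same contract from the other side: the library yields segment.start and segment.end as float seconds, and scaling by 1e9 expresses them as nanoseconds so the Go core can read them directly as time.Duration. Note that as written, int(segment.start)*1e9 applies int() before the scaling, so this backend reports segment boundaries at whole-second granularity.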
@@ -12,8 +12,7 @@ import (
 	"github.com/mudler/LocalAI/pkg/model"
 )
 
-func ModelTranscription(audio, language string, translate bool, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
+func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
 
 	if modelConfig.Backend == "" {
 		modelConfig.Backend = model.WhisperBackend
 	}
@@ -2,32 +2,42 @@ package cli
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
+	"strings"
 
 	"github.com/mudler/LocalAI/core/backend"
 	cliContext "github.com/mudler/LocalAI/core/cli/context"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/gallery"
+	"github.com/mudler/LocalAI/core/schema"
+	"github.com/mudler/LocalAI/pkg/format"
 	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/mudler/LocalAI/pkg/system"
 	"github.com/mudler/xlog"
 )
 
 type TranscriptCMD struct {
-	Filename string `arg:""`
+	Filename string `arg:"" name:"file" help:"Audio file to transcribe" type:"path"`
 
 	Backend   string `short:"b" default:"whisper" help:"Backend to run the transcription model"`
 	Model     string `short:"m" required:"" help:"Model name to run the TTS"`
 	Language  string `short:"l" help:"Language of the audio file"`
-	Translate bool   `short:"c" help:"Translate the transcription to english"`
+	Translate bool   `short:"c" help:"Translate the transcription to English"`
 	Diarize   bool   `short:"d" help:"Mark speaker turns"`
 	Threads   int    `short:"t" default:"1" help:"Number of threads used for parallel computation"`
-	ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
-	Prompt     string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"`
+	BackendsPath     string                                 `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
+	ModelsPath       string                                 `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
+	BackendGalleries string                                 `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
+	Prompt           string                                 `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"`
+	ResponseFormat   schema.TranscriptionResponseFormatType `short:"f" default:"" help:"Response format for Whisper models, can be one of (txt, lrc, srt, vtt, json, json_verbose)"`
+	PrettyPrint      bool                                   `help:"Used with response_format json or json_verbose for pretty printing"`
 }
 
 func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
 	systemState, err := system.GetSystemState(
+		system.WithBackendPath(t.BackendsPath),
 		system.WithModelPath(t.ModelsPath),
 	)
 	if err != nil {
@@ -40,6 +50,11 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
 
 	cl := config.NewModelConfigLoader(t.ModelsPath)
 	ml := model.NewModelLoader(systemState)
 
+	if err := gallery.RegisterBackends(systemState, ml); err != nil {
+		xlog.Error("error registering external backends", "error", err)
+	}
+
 	if err := cl.LoadModelConfigsFromPath(t.ModelsPath); err != nil {
 		return err
 	}
@@ -62,8 +77,29 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
 	if err != nil {
 		return err
 	}
-	for _, segment := range tr.Segments {
-		fmt.Println(segment.Start.String(), "-", segment.Text)
+	switch t.ResponseFormat {
+	case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt, schema.TranscriptionResponseFormatText:
+		fmt.Println(format.TranscriptionResponse(tr, t.ResponseFormat))
+	case schema.TranscriptionResponseFormatJson:
+		tr.Segments = nil
+		fallthrough
+	case schema.TranscriptionResponseFormatJsonVerbose:
+		var mtr []byte
+		var err error
+		if t.PrettyPrint {
+			mtr, err = json.MarshalIndent(tr, "", " ")
+		} else {
+			mtr, err = json.Marshal(tr)
+		}
+		if err != nil {
+			return err
+		}
+		fmt.Println(string(mtr))
+	default:
+		for _, segment := range tr.Segments {
+			fmt.Println(segment.Start.String(), "-", strings.TrimSpace(segment.Text))
+		}
 	}
 	return nil
 }
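With this switch the CLI mirrors the HTTP endpoint's formats. A hypothetical invocation (flag names derived from the struct tags above): local-ai transcript -m whisper-1 -f srt audio.wav prints an SRT document, -f json --pretty-print emits indented JSON with the segment list stripped, and omitting -f keeps the previous per-segment "start - text" listing.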
@@ -1,6 +1,7 @@
 package openai
 
 import (
+	"errors"
 	"io"
 	"net/http"
 	"os"
@@ -12,6 +13,7 @@ import (
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/schema"
+	"github.com/mudler/LocalAI/pkg/format"
 	model "github.com/mudler/LocalAI/pkg/model"
 
 	"github.com/mudler/xlog"
@@ -38,6 +40,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
 
 		diarize := c.FormValue("diarize") != "false"
 		prompt := c.FormValue("prompt")
+		responseFormat := schema.TranscriptionResponseFormatType(c.FormValue("response_format"))
 
 		// retrieve the file data from the request
 		file, err := c.FormFile("file")
@@ -76,7 +79,17 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
 		}
 
 		xlog.Debug("Transcribed", "transcription", tr)
-		// TODO: handle different outputs here
-		return c.JSON(http.StatusOK, tr)
+		switch responseFormat {
+		case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatText, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt:
+			return c.String(http.StatusOK, format.TranscriptionResponse(tr, responseFormat))
+		case schema.TranscriptionResponseFormatJson:
+			tr.Segments = nil
+			fallthrough
+		case schema.TranscriptionResponseFormatJsonVerbose, "": // maintain backwards compatibility
+			return c.JSON(http.StatusOK, tr)
+		default:
+			return errors.New("invalid response_format")
+		}
 	}
 }
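In HTTP terms this is a sketch of the contract: a multipart POST to the OpenAI-compatible transcription endpoint with form fields file, model, and response_format=vtt (or txt, lrc, srt) now returns the formatted document as a plain-text body via c.String, json returns the JSON payload with segments stripped, and json_verbose or an empty response_format preserves the exact pre-PR JSON response.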
@@ -107,6 +107,17 @@ type ImageGenerationResponseFormat string
 
 type ChatCompletionResponseFormatType string
 
+type TranscriptionResponseFormatType string
+
+const (
+	TranscriptionResponseFormatText        = TranscriptionResponseFormatType("txt")
+	TranscriptionResponseFormatSrt         = TranscriptionResponseFormatType("srt")
+	TranscriptionResponseFormatVtt         = TranscriptionResponseFormatType("vtt")
+	TranscriptionResponseFormatLrc         = TranscriptionResponseFormatType("lrc")
+	TranscriptionResponseFormatJson        = TranscriptionResponseFormatType("json")
+	TranscriptionResponseFormatJsonVerbose = TranscriptionResponseFormatType("json_verbose")
+)
+
 type ChatCompletionResponseFormat struct {
 	Type ChatCompletionResponseFormatType `json:"type,omitempty"`
 }
@@ -11,6 +11,6 @@ type TranscriptionSegment struct {
 }
 
 type TranscriptionResult struct {
-	Segments []TranscriptionSegment `json:"segments"`
+	Segments []TranscriptionSegment `json:"segments,omitempty"`
 	Text     string                 `json:"text"`
 }
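The omitempty added here is what makes the plain json format work in both switches above: once tr.Segments is set to nil, marshalling produces only the text field, while json_verbose keeps the full segment list.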
@@ -18,10 +18,6 @@ import (
 	"github.com/mudler/xlog"
 )
 
-const (
-	YAML_EXTENSION = ".yaml"
-)
-
 // InstallModels will preload models from the given list of URLs and galleries
 // It will download the model if it is not already present in the model path
 // It will also try to resolve if the model is an embedded model YAML configuration

File diff suppressed because one or more lines are too long
go.mod | 3

@@ -67,9 +67,10 @@ require (
 require (
 	github.com/ghodss/yaml v1.0.0 // indirect
 	github.com/labstack/gommon v0.4.2 // indirect
+	github.com/openai/openai-go/v3 v3.17.0 // indirect
 	github.com/swaggo/files/v2 v2.0.2 // indirect
 	github.com/tidwall/gjson v1.18.0 // indirect
-	github.com/tidwall/match v1.1.1 // indirect
+	github.com/tidwall/match v1.2.0 // indirect
 	github.com/tidwall/pretty v1.2.1 // indirect
 	github.com/tidwall/sjson v1.2.5 // indirect
 	github.com/valyala/fasttemplate v1.2.2 // indirect
go.sum | 4

@@ -565,6 +565,8 @@ github.com/onsi/ginkgo/v2 v2.27.5 h1:ZeVgZMx2PDMdJm/+w5fE/OyG6ILo1Y3e+QX4zSR0zTE
 github.com/onsi/ginkgo/v2 v2.27.5/go.mod h1:ArE1D/XhNXBXCBkKOLkbsb2c81dQHCRcF5zwn/ykDRo=
 github.com/onsi/gomega v1.39.0 h1:y2ROC3hKFmQZJNFeGAMeHZKkjBL65mIZcvrLQBF9k6Q=
 github.com/onsi/gomega v1.39.0/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4=
+github.com/openai/openai-go/v3 v3.17.0 h1:CfTkmQoItolSyW+bHOUF190KuX5+1Zv6MC0Gb4wAwy8=
+github.com/openai/openai-go/v3 v3.17.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
 github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
 github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -769,6 +771,8 @@ github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
 github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
 github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
 github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
+github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
+github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
 github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
 github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
 github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
pkg/format/transcription.go (new file) | 41

@@ -0,0 +1,41 @@
+package format
+
+import (
+	"fmt"
+	"strings"
+	"time"
+
+	"github.com/mudler/LocalAI/core/schema"
+)
+
+func TranscriptionResponse(tr *schema.TranscriptionResult, resFmt schema.TranscriptionResponseFormatType) string {
+	var out string
+	if resFmt == schema.TranscriptionResponseFormatLrc {
+		out = "[by:LocalAI]\n[re:LocalAI]\n"
+	} else if resFmt == schema.TranscriptionResponseFormatVtt {
+		out = "WEBVTT"
+	}
+
+	for i, s := range tr.Segments {
+		switch resFmt {
+		case schema.TranscriptionResponseFormatLrc:
+			m := s.Start.Milliseconds()
+			out += fmt.Sprintf("\n[%02d:%02d:%02d] %s", m/60000, (m/1000)%60, (m%1000)/10, strings.TrimSpace(s.Text))
+		case schema.TranscriptionResponseFormatSrt:
+			out += fmt.Sprintf("\n\n%d\n%s --> %s\n%s", i+1, durationStr(s.Start, ','), durationStr(s.End, ','), strings.TrimSpace(s.Text))
+		case schema.TranscriptionResponseFormatVtt:
+			out += fmt.Sprintf("\n\n%s --> %s\n%s\n", durationStr(s.Start, '.'), durationStr(s.End, '.'), strings.TrimSpace(s.Text))
+		case schema.TranscriptionResponseFormatText:
+			fallthrough
+		default:
+			out += fmt.Sprintf("\n%s", strings.TrimSpace(s.Text))
+		}
+	}
+
+	return out
+}
+
+func durationStr(d time.Duration, millisSeparator rune) string {
+	m := d.Milliseconds()
+	return fmt.Sprintf("%02d:%02d:%02d%c%03d", m/3600000, m/60000, int(d.Seconds())%60, millisSeparator, m%1000)
+}
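To make the formatting concrete, here is what the new file produces for a single invented segment (Start = 1s, End = 2.5s, Text = "Hello world"; values are illustrative, not from the PR), with leading separator newlines elided.

srt (comma as millisecond separator; durationStr(2.5s, ',') gives hours 2500/3600000 = 0, minutes 2500/60000 = 0, seconds int(2.5)%60 = 2, millis 2500%1000 = 500):

1
00:00:01,000 --> 00:00:02,500
Hello world

vtt (dot separator, WEBVTT header):

WEBVTT

00:00:01.000 --> 00:00:02.500
Hello world

lrc ([mm:ss:cc] stamps; 1000 ms gives minutes 1000/60000 = 0, seconds (1000/1000)%60 = 1, hundredths (1000%1000)/10 = 0):

[by:LocalAI]
[re:LocalAI]

[00:01:00] Hello world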
@@ -17,7 +17,7 @@ const (
 	LLamaCPP = "llama-cpp"
 )
 
-var Aliases map[string]string = map[string]string{
+var Aliases = map[string]string{
 	"go-llama":       LLamaCPP,
 	"llama":          LLamaCPP,
 	"embedded-store": LocalStoreBackend,
@@ -29,7 +29,7 @@ var Aliases map[string]string = map[string]string{
 	"stablediffusion": StableDiffusionGGMLBackend,
 }
 
-var TypeAlias map[string]string = map[string]string{
+var TypeAlias = map[string]string{
 	"sentencetransformers":   "SentenceTransformer",
 	"huggingface-embeddings": "SentenceTransformer",
 	"mamba":                  "Mamba",
@@ -75,7 +75,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 	// Check if the backend is provided as external
 	if uri, ok := ml.GetAllExternalBackends(o)[backend]; ok {
 		xlog.Debug("Loading external backend", "uri", uri)
-		// check if uri is a file or a address
+		// check if uri is a file or an address
 		if fi, err := os.Stat(uri); err == nil {
 			xlog.Debug("external backend is file", "file", fi)
 			serverAddress, err := getFreeAddress()
@@ -11,13 +11,14 @@ import (
 	"github.com/docker/go-connections/nat"
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
-	"github.com/sashabaranov/go-openai"
+	"github.com/openai/openai-go/v3"
+	"github.com/openai/openai-go/v3/option"
 	"github.com/testcontainers/testcontainers-go"
 	"github.com/testcontainers/testcontainers-go/wait"
 )
 
 var container testcontainers.Container
-var client *openai.Client
+var client openai.Client
 
 var containerImage = os.Getenv("LOCALAI_IMAGE")
 var containerImageTag = os.Getenv("LOCALAI_IMAGE_TAG")
@@ -37,26 +38,22 @@ func TestLocalAI(t *testing.T) {
 
 var _ = BeforeSuite(func() {
 
-	var defaultConfig openai.ClientConfig
 	if apiEndpoint == "" {
 		startDockerImage()
-		apiPort, err := container.MappedPort(context.Background(), nat.Port(defaultApiPort))
+		apiPort, err := container.MappedPort(context.Background(), defaultApiPort)
 		Expect(err).To(Not(HaveOccurred()))
 
-		defaultConfig = openai.DefaultConfig(apiKey)
 		apiEndpoint = "http://localhost:" + apiPort.Port() + "/v1" // So that other tests can reference this value safely.
-		defaultConfig.BaseURL = apiEndpoint
 	} else {
 		GinkgoWriter.Printf("docker apiEndpoint set from env: %q\n", apiEndpoint)
-		defaultConfig = openai.DefaultConfig(apiKey)
-		defaultConfig.BaseURL = apiEndpoint
 	}
+	opts := []option.RequestOption{option.WithAPIKey(apiKey), option.WithBaseURL(apiEndpoint)}
 
 	// Wait for API to be ready
-	client = openai.NewClientWithConfig(defaultConfig)
+	client = openai.NewClient(opts...)
 	Eventually(func() error {
-		_, err := client.ListModels(context.TODO())
+		_, err := client.Models.List(context.TODO())
 		return err
 	}, "50m").ShouldNot(HaveOccurred())
 })
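Note the type change in this suite: openai-go v3's NewClient returns a value (openai.Client), not a pointer, and is configured through option.RequestOption arguments, which replaces go-openai's ClientConfig/NewClientWithConfig pair.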
@@ -12,8 +12,8 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/schema"
|
"github.com/mudler/LocalAI/core/schema"
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
"github.com/sashabaranov/go-openai"
|
"github.com/openai/openai-go/v3"
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/openai/openai-go/v3/option"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("E2E test", func() {
|
var _ = Describe("E2E test", func() {
|
||||||
@@ -30,14 +30,13 @@ var _ = Describe("E2E test", func() {
|
|||||||
Context("text", func() {
|
Context("text", func() {
|
||||||
It("correctly", func() {
|
It("correctly", func() {
|
||||||
model := "gpt-4"
|
model := "gpt-4"
|
||||||
resp, err := client.CreateChatCompletion(context.TODO(),
|
resp, err := client.Chat.Completions.New(context.TODO(),
|
||||||
openai.ChatCompletionRequest{
|
openai.ChatCompletionNewParams{
|
||||||
Model: model, Messages: []openai.ChatCompletionMessage{
|
Model: model,
|
||||||
{
|
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||||
Role: "user",
|
openai.UserMessage("How much is 2+2?"),
|
||||||
Content: "How much is 2+2?",
|
},
|
||||||
},
|
})
|
||||||
}})
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
||||||
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four")), fmt.Sprint(resp.Choices[0].Message.Content))
|
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four")), fmt.Sprint(resp.Choices[0].Message.Content))
|
||||||
@@ -46,39 +45,36 @@ var _ = Describe("E2E test", func() {
|
|||||||
|
|
||||||
Context("function calls", func() {
|
Context("function calls", func() {
|
||||||
It("correctly invoke", func() {
|
It("correctly invoke", func() {
|
||||||
params := jsonschema.Definition{
|
params := openai.FunctionParameters{
|
||||||
Type: jsonschema.Object,
|
"type": "object",
|
||||||
Properties: map[string]jsonschema.Definition{
|
"properties": map[string]any{
|
||||||
"location": {
|
"location": map[string]string{
|
||||||
Type: jsonschema.String,
|
"type": "string",
|
||||||
Description: "The city and state, e.g. San Francisco, CA",
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
},
|
},
|
||||||
"unit": {
|
"unit": map[string]any{
|
||||||
Type: jsonschema.String,
|
"type": "string",
|
||||||
Enum: []string{"celsius", "fahrenheit"},
|
"enum": []string{"celsius", "fahrenheit"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Required: []string{"location"},
|
"required": []string{"location"},
|
||||||
}
|
}
|
||||||
|
|
||||||
f := openai.FunctionDefinition{
|
tool := openai.ChatCompletionToolUnionParam{
|
||||||
Name: "get_current_weather",
|
OfFunction: &openai.ChatCompletionFunctionToolParam{
|
||||||
Description: "Get the current weather in a given location",
|
Function: openai.FunctionDefinitionParam{
|
||||||
Parameters: params,
|
Name: "get_current_weather",
|
||||||
}
|
Description: openai.String("Get the current weather in a given location"),
|
||||||
t := openai.Tool{
|
Parameters: params,
|
||||||
Type: openai.ToolTypeFunction,
|
},
|
||||||
Function: &f,
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
dialogue := []openai.ChatCompletionMessage{
|
resp, err := client.Chat.Completions.New(context.TODO(),
|
||||||
{Role: openai.ChatMessageRoleUser, Content: "What is the weather in Boston today?"},
|
openai.ChatCompletionNewParams{
|
||||||
}
|
Model: openai.ChatModelGPT4,
|
||||||
resp, err := client.CreateChatCompletion(context.TODO(),
|
Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage("What is the weather in Boston today?")},
|
||||||
openai.ChatCompletionRequest{
|
Tools: []openai.ChatCompletionToolUnionParam{tool},
|
||||||
Model: openai.GPT4,
|
|
||||||
Messages: dialogue,
|
|
||||||
Tools: []openai.Tool{t},
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
@@ -90,23 +86,21 @@ var _ = Describe("E2E test", func() {
|
|||||||
Expect(msg.ToolCalls[0].Function.Arguments).To(ContainSubstring("Boston"), fmt.Sprint(msg.ToolCalls[0].Function.Arguments))
|
Expect(msg.ToolCalls[0].Function.Arguments).To(ContainSubstring("Boston"), fmt.Sprint(msg.ToolCalls[0].Function.Arguments))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("json", func() {
|
Context("json", func() {
|
||||||
It("correctly", func() {
|
It("correctly", func() {
|
||||||
model := "gpt-4"
|
model := "gpt-4"
|
||||||
|
|
||||||
req := openai.ChatCompletionRequest{
|
resp, err := client.Chat.Completions.New(context.TODO(),
|
||||||
ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject},
|
openai.ChatCompletionNewParams{
|
||||||
Model: model,
|
Model: model,
|
||||||
Messages: []openai.ChatCompletionMessage{
|
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||||
{
|
openai.UserMessage("Generate a JSON object of an animal with 'name', 'gender' and 'legs' fields"),
|
||||||
|
|
||||||
Role: "user",
|
|
||||||
Content: "Generate a JSON object of an animal with 'name', 'gender' and 'legs' fields",
|
|
||||||
},
|
},
|
||||||
},
|
ResponseFormat: openai.ChatCompletionNewParamsResponseFormatUnion{
|
||||||
}
|
OfJSONObject: &openai.ResponseFormatJSONObjectParam{},
|
||||||
|
},
|
||||||
resp, err := client.CreateChatCompletion(context.TODO(), req)
|
})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
||||||
|
|
||||||
@@ -121,23 +115,23 @@ var _ = Describe("E2E test", func() {
|
|||||||
|
|
||||||
Context("images", func() {
|
Context("images", func() {
|
||||||
It("correctly", func() {
|
It("correctly", func() {
|
||||||
req := openai.ImageRequest{
|
resp, err := client.Images.Generate(context.TODO(),
|
||||||
Prompt: "test",
|
openai.ImageGenerateParams{
|
||||||
Quality: "1",
|
Prompt: "test",
|
||||||
Size: openai.CreateImageSize256x256,
|
Size: openai.ImageGenerateParamsSize256x256,
|
||||||
}
|
Quality: openai.ImageGenerateParamsQualityLow,
|
||||||
resp, err := client.CreateImage(context.TODO(), req)
|
})
|
||||||
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("error sending image request %+v", req))
|
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("error sending image request"))
|
||||||
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
||||||
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
|
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
|
||||||
})
|
})
|
||||||
It("correctly changes the response format to url", func() {
|
It("correctly changes the response format to url", func() {
|
||||||
resp, err := client.CreateImage(context.TODO(),
|
resp, err := client.Images.Generate(context.TODO(),
|
||||||
openai.ImageRequest{
|
openai.ImageGenerateParams{
|
||||||
Prompt: "test",
|
Prompt: "test",
|
||||||
Size: openai.CreateImageSize256x256,
|
Size: openai.ImageGenerateParamsSize256x256,
|
||||||
Quality: "1",
|
ResponseFormat: openai.ImageGenerateParamsResponseFormatURL,
|
||||||
ResponseFormat: openai.CreateImageResponseFormatURL,
|
Quality: openai.ImageGenerateParamsQualityLow,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
@@ -145,12 +139,11 @@ var _ = Describe("E2E test", func() {
|
|||||||
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
|
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
|
||||||
})
|
})
|
||||||
It("correctly changes the response format to base64", func() {
|
It("correctly changes the response format to base64", func() {
|
||||||
resp, err := client.CreateImage(context.TODO(),
|
resp, err := client.Images.Generate(context.TODO(),
|
||||||
openai.ImageRequest{
|
openai.ImageGenerateParams{
|
||||||
Prompt: "test",
|
Prompt: "test",
|
||||||
Size: openai.CreateImageSize256x256,
|
Size: openai.ImageGenerateParamsSize256x256,
|
||||||
Quality: "1",
|
ResponseFormat: openai.ImageGenerateParamsResponseFormatB64JSON,
|
||||||
ResponseFormat: openai.CreateImageResponseFormatB64JSON,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
@@ -158,22 +151,27 @@ var _ = Describe("E2E test", func() {
|
|||||||
Expect(resp.Data[0].B64JSON).ToNot(BeEmpty(), fmt.Sprint(resp.Data[0].B64JSON))
|
Expect(resp.Data[0].B64JSON).ToNot(BeEmpty(), fmt.Sprint(resp.Data[0].B64JSON))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("embeddings", func() {
|
Context("embeddings", func() {
|
||||||
It("correctly", func() {
|
It("correctly", func() {
|
||||||
resp, err := client.CreateEmbeddings(context.TODO(),
|
resp, err := client.Embeddings.New(context.TODO(),
|
||||||
openai.EmbeddingRequestStrings{
|
openai.EmbeddingNewParams{
|
||||||
Input: []string{"doc"},
|
Input: openai.EmbeddingNewParamsInputUnion{
|
||||||
Model: openai.AdaEmbeddingV2,
|
OfArrayOfStrings: []string{"doc"},
|
||||||
|
},
|
||||||
|
Model: openai.EmbeddingModelTextEmbeddingAda002,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
|
||||||
Expect(resp.Data[0].Embedding).ToNot(BeEmpty())
|
Expect(resp.Data[0].Embedding).ToNot(BeEmpty())
|
||||||
|
|
||||||
resp2, err := client.CreateEmbeddings(context.TODO(),
|
resp2, err := client.Embeddings.New(context.TODO(),
|
||||||
openai.EmbeddingRequestStrings{
|
openai.EmbeddingNewParams{
|
||||||
Input: []string{"cat"},
|
Input: openai.EmbeddingNewParamsInputUnion{
|
||||||
Model: openai.AdaEmbeddingV2,
|
OfArrayOfStrings: []string{"cat"},
|
||||||
|
},
|
||||||
|
Model: openai.EmbeddingModelTextEmbeddingAda002,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
@@ -181,10 +179,12 @@ var _ = Describe("E2E test", func() {
|
|||||||
Expect(resp2.Data[0].Embedding).ToNot(BeEmpty())
|
Expect(resp2.Data[0].Embedding).ToNot(BeEmpty())
|
||||||
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[0].Embedding))
|
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[0].Embedding))
|
||||||
|
|
||||||
resp3, err := client.CreateEmbeddings(context.TODO(),
|
resp3, err := client.Embeddings.New(context.TODO(),
|
||||||
openai.EmbeddingRequestStrings{
|
openai.EmbeddingNewParams{
|
||||||
Input: []string{"doc", "cat"},
|
Input: openai.EmbeddingNewParamsInputUnion{
|
||||||
Model: openai.AdaEmbeddingV2,
|
OfArrayOfStrings: []string{"doc", "cat"},
|
||||||
|
},
|
||||||
|
Model: openai.EmbeddingModelTextEmbeddingAda002,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
@@ -195,66 +195,101 @@ var _ = Describe("E2E test", func() {
|
|||||||
Expect(resp3.Data[0].Embedding).ToNot(Equal(resp3.Data[1].Embedding))
|
Expect(resp3.Data[0].Embedding).ToNot(Equal(resp3.Data[1].Embedding))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("vision", func() {
|
Context("vision", func() {
|
||||||
It("correctly", func() {
|
It("correctly", func() {
|
||||||
model := "gpt-4o"
|
model := "gpt-4o"
|
||||||
resp, err := client.CreateChatCompletion(context.TODO(),
|
resp, err := client.Chat.Completions.New(context.TODO(),
|
||||||
openai.ChatCompletionRequest{
|
openai.ChatCompletionNewParams{
|
||||||
Model: model, Messages: []openai.ChatCompletionMessage{
|
Model: model,
|
||||||
|
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||||
{
|
{
|
||||||
|
OfUser: &openai.ChatCompletionUserMessageParam{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
MultiContent: []openai.ChatMessagePart{
|
Content: openai.ChatCompletionUserMessageParamContentUnion{
|
||||||
{
|
OfArrayOfContentParts: []openai.ChatCompletionContentPartUnionParam{
|
||||||
Type: openai.ChatMessagePartTypeText,
|
{
|
||||||
Text: "What is in the image?",
|
OfText: &openai.ChatCompletionContentPartTextParam{
|
||||||
},
|
Type: "text",
|
||||||
{
|
Text: "What is in the image?",
|
||||||
Type: openai.ChatMessagePartTypeImageURL,
|
},
|
||||||
ImageURL: &openai.ChatMessageImageURL{
|
},
|
||||||
URL: "https://picsum.photos/id/22/4434/3729",
|
{
|
||||||
Detail: openai.ImageURLDetailLow,
|
OfImageURL: &openai.ChatCompletionContentPartImageParam{
|
||||||
|
ImageURL: openai.ChatCompletionContentPartImageImageURLParam{
|
||||||
|
URL: "https://picsum.photos/id/22/4434/3729",
|
||||||
|
Detail: "low",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}})
|
},
|
||||||
|
})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp))
|
||||||
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("man"), ContainSubstring("road")), fmt.Sprint(resp.Choices[0].Message.Content))
|
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("man"), ContainSubstring("road")), fmt.Sprint(resp.Choices[0].Message.Content))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("text to audio", func() {
|
Context("text to audio", func() {
|
||||||
It("correctly", func() {
|
It("correctly", func() {
|
||||||
res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{
|
res, err := client.Audio.Speech.New(context.Background(), openai.AudioSpeechNewParams{
|
||||||
Model: openai.TTSModel1,
|
Model: openai.SpeechModelTTS1,
|
||||||
Input: "Hello!",
|
Input: "Hello!",
|
||||||
Voice: openai.VoiceAlloy,
|
Voice: openai.AudioSpeechNewParamsVoiceAlloy,
|
||||||
})
|
})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
defer res.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
_, err = io.ReadAll(res)
|
_, err = io.ReadAll(res.Body)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("audio to text", func() {
|
Context("audio to text", func() {
|
||||||
It("correctly", func() {
|
It("correctly", func() {
|
||||||
|
|
||||||
downloadURL := "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
|
downloadURL := "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
|
||||||
file, err := downloadHttpFile(downloadURL)
|
file, err := downloadHttpFile(downloadURL)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
req := openai.AudioRequest{
|
fileHandle, err := os.Open(file)
|
||||||
Model: openai.Whisper1,
|
|
||||||
FilePath: file,
|
|
||||||
}
|
|
||||||
resp, err := client.CreateTranscription(context.Background(), req)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer fileHandle.Close()
|
||||||
|
|
||||||
|
transcriptionResp, err := client.Audio.Transcriptions.New(context.Background(), openai.AudioTranscriptionNewParams{
|
||||||
|
Model: openai.AudioModelWhisper1,
|
||||||
|
File: fileHandle,
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
resp := transcriptionResp.AsTranscription()
|
||||||
Expect(resp.Text).To(ContainSubstring("This is the"), fmt.Sprint(resp.Text))
|
Expect(resp.Text).To(ContainSubstring("This is the"), fmt.Sprint(resp.Text))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("with VTT format", func() {
|
||||||
|
downloadURL := "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
|
||||||
|
file, err := downloadHttpFile(downloadURL)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
fileHandle, err := os.Open(file)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
defer fileHandle.Close()
|
||||||
|
|
||||||
|
var resp string
|
||||||
|
_, err = client.Audio.Transcriptions.New(context.Background(), openai.AudioTranscriptionNewParams{
|
||||||
|
Model: openai.AudioModelWhisper1,
|
||||||
|
File: fileHandle,
|
||||||
|
ResponseFormat: openai.AudioResponseFormatVTT,
|
||||||
|
}, option.WithResponseBodyInto(&resp))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(resp).To(ContainSubstring("This is the"), resp)
|
||||||
|
Expect(resp).To(ContainSubstring("WEBVTT"), resp)
|
||||||
|
Expect(resp).To(ContainSubstring("00:00:00.000 -->"), resp)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("vad", func() {
|
Context("vad", func() {
|
||||||
It("correctly", func() {
|
It("correctly", func() {
|
||||||
modelName := "silero-vad"
|
modelName := "silero-vad"
|
||||||
@@ -283,6 +318,7 @@ var _ = Describe("E2E test", func() {
|
|||||||
Expect(deserializedResponse.Segments).ToNot(BeZero())
|
Expect(deserializedResponse.Segments).ToNot(BeZero())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("reranker", func() {
|
Context("reranker", func() {
|
||||||
It("correctly", func() {
|
It("correctly", func() {
|
||||||
modelName := "jina-reranker-v1-base-en"
|
modelName := "jina-reranker-v1-base-en"
|
||||||
@@ -317,7 +353,6 @@ var _ = Describe("E2E test", func() {
|
|||||||
Expect(err).To(BeNil())
|
Expect(err).To(BeNil())
|
||||||
Expect(deserializedResponse).ToNot(BeZero())
|
Expect(deserializedResponse).ToNot(BeZero())
|
||||||
Expect(deserializedResponse.Model).To(Equal(modelName))
|
Expect(deserializedResponse.Model).To(Equal(modelName))
|
||||||
//Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0))
|
|
||||||
Expect(len(deserializedResponse.Results)).To(Equal(expectResults))
|
Expect(len(deserializedResponse.Results)).To(Equal(expectResults))
|
||||||
// Assert that relevance scores are in decreasing order
|
// Assert that relevance scores are in decreasing order
|
||||||
for i := 1; i < len(deserializedResponse.Results); i++ {
|
for i := 1; i < len(deserializedResponse.Results); i++ {
|
||||||
|
@@ -17,14 +17,14 @@ import (
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 	"github.com/phayes/freeport"
-	"github.com/sashabaranov/go-openai"
 	"gopkg.in/yaml.v3"
 
 	"github.com/mudler/xlog"
+	"github.com/openai/openai-go/v3"
+	"github.com/openai/openai-go/v3/option"
 )
 
 var (
-	localAIURL       string
 	anthropicBaseURL string
 	tmpDir           string
 	backendPath      string
@@ -33,7 +33,7 @@ var (
 	app       *echo.Echo
 	appCtx    context.Context
 	appCancel context.CancelFunc
-	client    *openai.Client
+	client    openai.Client
 	apiPort   int
 	apiURL    string
 	mockBackendPath string
@@ -129,7 +129,6 @@ var _ = BeforeSuite(func() {
 	Expect(err).ToNot(HaveOccurred())
 	apiPort = port
 	apiURL = fmt.Sprintf("http://127.0.0.1:%d/v1", apiPort)
-	localAIURL = apiURL
 	// Anthropic SDK appends /v1/messages to base URL; use base without /v1 so requests go to /v1/messages
 	anthropicBaseURL = fmt.Sprintf("http://127.0.0.1:%d", apiPort)
 
@@ -141,12 +140,10 @@ var _ = BeforeSuite(func() {
 	}()
 
 	// Wait for server to be ready
-	defaultConfig := openai.DefaultConfig("")
-	defaultConfig.BaseURL = apiURL
-	client = openai.NewClientWithConfig(defaultConfig)
+	client = openai.NewClient(option.WithBaseURL(apiURL))
 
 	Eventually(func() error {
-		_, err := client.ListModels(context.TODO())
+		_, err := client.Models.List(context.TODO())
 		return err
 	}, "2m").ShouldNot(HaveOccurred())
 })
@@ -9,22 +9,19 @@ import (
 
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
-	"github.com/sashabaranov/go-openai"
+	"github.com/openai/openai-go/v3"
 )
 
 var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
 	Describe("Text Generation APIs", func() {
 		Context("Predict (Chat Completions)", func() {
 			It("should return mocked response", func() {
-				resp, err := client.CreateChatCompletion(
+				resp, err := client.Chat.Completions.New(
 					context.TODO(),
-					openai.ChatCompletionRequest{
+					openai.ChatCompletionNewParams{
 						Model: "mock-model",
-						Messages: []openai.ChatCompletionMessage{
-							{
-								Role:    "user",
-								Content: "Hello",
-							},
+						Messages: []openai.ChatCompletionMessageParamUnion{
+							openai.UserMessage("Hello"),
 						},
 					},
 				)
@@ -36,31 +33,23 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
 
 		Context("PredictStream (Streaming Chat Completions)", func() {
 			It("should stream mocked tokens", func() {
-				stream, err := client.CreateChatCompletionStream(
+				stream := client.Chat.Completions.NewStreaming(
 					context.TODO(),
-					openai.ChatCompletionRequest{
+					openai.ChatCompletionNewParams{
 						Model: "mock-model",
-						Messages: []openai.ChatCompletionMessage{
-							{
-								Role:    "user",
-								Content: "Hello",
-							},
+						Messages: []openai.ChatCompletionMessageParamUnion{
+							openai.UserMessage("Hello"),
 						},
 					},
 				)
-				Expect(err).ToNot(HaveOccurred())
-				defer stream.Close()
 
 				hasContent := false
-				for {
-					response, err := stream.Recv()
-					if err != nil {
-						break
-					}
+				for stream.Next() {
+					response := stream.Current()
 					if len(response.Choices) > 0 && response.Choices[0].Delta.Content != "" {
 						hasContent = true
 					}
 				}
+				Expect(stream.Err()).ToNot(HaveOccurred())
 				Expect(hasContent).To(BeTrue())
 			})
 		})
@@ -68,11 +57,13 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
 
 	Describe("Embeddings API", func() {
 		It("should return mocked embeddings", func() {
-			resp, err := client.CreateEmbeddings(
+			resp, err := client.Embeddings.New(
 				context.TODO(),
-				openai.EmbeddingRequest{
+				openai.EmbeddingNewParams{
 					Model: "mock-model",
-					Input: []string{"test"},
+					Input: openai.EmbeddingNewParamsInputUnion{
+						OfArrayOfStrings: []string{"test"},
+					},
 				},
 			)
 			Expect(err).ToNot(HaveOccurred())
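The streaming rewrite above is the most visible behavioral difference in the client migration: go-openai's Recv loop with per-call error checks becomes an iterator where the terminal error is inspected once, after the loop. A self-contained sketch of the pattern (endpoint and model names are placeholders):

	package main

	import (
		"context"
		"fmt"
		"strings"

		"github.com/openai/openai-go/v3"
		"github.com/openai/openai-go/v3/option"
	)

	func main() {
		// Placeholder endpoint; any OpenAI-compatible server works here.
		client := openai.NewClient(option.WithBaseURL("http://localhost:8080/v1"))

		stream := client.Chat.Completions.NewStreaming(context.TODO(),
			openai.ChatCompletionNewParams{
				Model:    "mock-model",
				Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage("Hello")},
			})

		var sb strings.Builder
		for stream.Next() {
			chunk := stream.Current()
			if len(chunk.Choices) > 0 {
				sb.WriteString(chunk.Choices[0].Delta.Content)
			}
		}
		// The iterator surfaces the terminal error once, after the loop,
		// replacing go-openai's per-Recv error handling.
		if err := stream.Err(); err != nil {
			panic(err)
		}
		fmt.Println(sb.String())
	}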