diff --git a/core/config/model_config_loader.go b/core/config/model_config_loader.go index f0f2c3338..32b3a8ac4 100644 --- a/core/config/model_config_loader.go +++ b/core/config/model_config_loader.go @@ -287,18 +287,20 @@ func (bcl *ModelConfigLoader) Preload(modelPath string) error { if config.IsModelURL() { modelFileName := config.ModelFileName() uri := downloader.URI(config.Model) - // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) { - err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status) - if err != nil { - return err + if uri.ResolveURL() != config.Model { + // check if file exists + if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) { + err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status) + if err != nil { + return err + } } - } - cc := bcl.configs[i] - c := &cc - c.PredictionOptions.Model = modelFileName - bcl.configs[i] = *c + cc := bcl.configs[i] + c := &cc + c.PredictionOptions.Model = modelFileName + bcl.configs[i] = *c + } } if config.IsMMProjURL() { diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 0129c5fdc..3a23589d1 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -214,16 +214,26 @@ func (s URI) ResolveURL() string { repository = strings.Replace(repository, HuggingFacePrefix2, "", 1) // convert repository to a full URL. // e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf - owner := strings.Split(repository, "/")[0] - repo := strings.Split(repository, "/")[1] + + repoPieces := strings.Split(repository, "/") + repoID := strings.Split(repository, "@") + if len(repoPieces) < 3 { + return string(s) + } + + owner := repoPieces[0] + repo := repoPieces[1] branch := "main" - if strings.Contains(repo, "@") { - branch = strings.Split(repository, "@")[1] - } - filepath := strings.Split(repository, "/")[2] - if strings.Contains(filepath, "@") { - filepath = strings.Split(filepath, "@")[0] + filepath := repoPieces[2] + + if len(repoID) > 1 { + if strings.Contains(repo, "@") { + branch = repoID[1] + } + if strings.Contains(filepath, "@") { + filepath = repoID[2] + } } return fmt.Sprintf("%s/%s/%s/resolve/%s/%s", HF_ENDPOINT, owner, repo, branch, filepath)