mirror of
https://github.com/meilisearch/meilisearch.git
synced 2024-11-22 18:17:39 +08:00
Merge #4304
4304: Add CUDA GPU support for Hugging Face embedders r=Kerollmops a=dureuill Adds a "cuda" feature to `milli`. Compiling with this feature requires that the CUDA support library be installed (see "with CUDA support" paragraph in https://huggingface.github.io/candle/guide/installation.html), and adds CUDA support to the `huggingFace` embedder. To enable GPU support, users will need to: 1. Have a compatible NVidia GPU under Linux 2. Follow [the guide](https://huggingface.github.io/candle/guide/installation.html) to install the CUDA dependencies 3. Compile Meilisearch with the `cuda` feature: `cargo build --release --features cuda` # Impact Enabling the CUDA feature allows to use an available GPU to compute embeddings with a `huggingFace` embedder. On an AWS Graviton 2, this yields a x3 - x5 improvement on indexing time. # Technical details - I had to change the CI so that the cuda feature is not included in the `Tests all features` workflow - To achieve that, I had to add a binary following the `cargo xtask` design pattern, to list all features excepted the cuda one. - I then changed the workflow accordingly (renamed to "Tests almost all features" 😉) - A test run of the new feature was done on a temporary version of this PR that had it enabled for PRs: [See the results here](https://github.com/meilisearch/meilisearch/actions/runs/7461331929/job/20301216732) Co-authored-by: Louis Dureuil <louis@meilisearch.com>
This commit is contained in:
commit
b6fc181993
2
.cargo/config.toml
Normal file
2
.cargo/config.toml
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
[alias]
|
||||||
|
xtask = "run --package xtask --"
|
18
.github/workflows/test-suite.yml
vendored
18
.github/workflows/test-suite.yml
vendored
@ -82,7 +82,7 @@ jobs:
|
|||||||
args: --locked --release --all
|
args: --locked --release --all
|
||||||
|
|
||||||
test-all-features:
|
test-all-features:
|
||||||
name: Tests all features
|
name: Tests almost all features
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
container:
|
container:
|
||||||
# Use ubuntu-18.04 to compile with glibc 2.27, which are the production expectations
|
# Use ubuntu-18.04 to compile with glibc 2.27, which are the production expectations
|
||||||
@ -98,16 +98,12 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
toolchain: stable
|
toolchain: stable
|
||||||
override: true
|
override: true
|
||||||
- name: Run cargo build with all features
|
- name: Run cargo build with almost all features
|
||||||
uses: actions-rs/cargo@v1
|
run: |
|
||||||
with:
|
cargo build --workspace --locked --release --features "$(cargo xtask list-features --exclude-feature cuda)"
|
||||||
command: build
|
- name: Run cargo test with almost all features
|
||||||
args: --workspace --locked --release --all-features
|
run: |
|
||||||
- name: Run cargo test with all features
|
cargo test --workspace --locked --release --features "$(cargo xtask list-features --exclude-feature cuda)"
|
||||||
uses: actions-rs/cargo@v1
|
|
||||||
with:
|
|
||||||
command: test
|
|
||||||
args: --workspace --locked --release --all-features
|
|
||||||
|
|
||||||
test-disabled-tokenization:
|
test-disabled-tokenization:
|
||||||
name: Test disabled tokenization
|
name: Test disabled tokenization
|
||||||
|
@ -75,6 +75,12 @@ If you get a "Too many open files" error you might want to increase the open fil
|
|||||||
ulimit -Sn 3000
|
ulimit -Sn 3000
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Build tools
|
||||||
|
|
||||||
|
Meilisearch follows the [cargo xtask](https://github.com/matklad/cargo-xtask) workflow to provide some build tools.
|
||||||
|
|
||||||
|
Run `cargo xtask --help` from the root of the repository to find out what is available.
|
||||||
|
|
||||||
## Git Guidelines
|
## Git Guidelines
|
||||||
|
|
||||||
### Git Branches
|
### Git Branches
|
||||||
|
64
Cargo.lock
generated
64
Cargo.lock
generated
@ -700,12 +700,23 @@ dependencies = [
|
|||||||
"displaydoc",
|
"displaydoc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "camino"
|
||||||
|
version = "1.1.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c59e92b5a388f549b863a7bea62612c09f24c8393560709a54558a9abdfb3b9c"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "candle-core"
|
name = "candle-core"
|
||||||
version = "0.3.3"
|
version = "0.3.3"
|
||||||
source = "git+https://github.com/huggingface/candle.git#5270224f407502b82fe90bc2622894ce3871b002"
|
source = "git+https://github.com/huggingface/candle.git#5270224f407502b82fe90bc2622894ce3871b002"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"byteorder",
|
"byteorder",
|
||||||
|
"candle-kernels",
|
||||||
|
"cudarc",
|
||||||
"gemm",
|
"gemm",
|
||||||
"half 2.3.1",
|
"half 2.3.1",
|
||||||
"memmap2 0.9.3",
|
"memmap2 0.9.3",
|
||||||
@ -720,6 +731,16 @@ dependencies = [
|
|||||||
"zip",
|
"zip",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "candle-kernels"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "git+https://github.com/huggingface/candle.git#f4fcf6090045ac44122fd5f0a7e46db6e3e16528"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"glob",
|
||||||
|
"rayon",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "candle-nn"
|
name = "candle-nn"
|
||||||
version = "0.3.3"
|
version = "0.3.3"
|
||||||
@ -752,6 +773,29 @@ dependencies = [
|
|||||||
"wav",
|
"wav",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cargo-platform"
|
||||||
|
version = "0.1.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ceed8ef69d8518a5dda55c07425450b58a4e1946f4951eab6d7191ee86c2443d"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cargo_metadata"
|
||||||
|
version = "0.18.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2d886547e41f740c616ae73108f6eb70afe6d940c7bc697cb30f13daec073037"
|
||||||
|
dependencies = [
|
||||||
|
"camino",
|
||||||
|
"cargo-platform",
|
||||||
|
"semver",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"thiserror",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cargo_toml"
|
name = "cargo_toml"
|
||||||
version = "0.18.0"
|
version = "0.18.0"
|
||||||
@ -1163,6 +1207,15 @@ dependencies = [
|
|||||||
"memchr",
|
"memchr",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cudarc"
|
||||||
|
version = "0.10.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9395df0cab995685664e79cc35ad6302bf08fb9c5d82301875a183affe1278b1"
|
||||||
|
dependencies = [
|
||||||
|
"half 2.3.1",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "darling"
|
name = "darling"
|
||||||
version = "0.14.4"
|
version = "0.14.4"
|
||||||
@ -4827,6 +4880,9 @@ name = "semver"
|
|||||||
version = "1.0.18"
|
version = "1.0.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b0293b4b29daaf487284529cc2f5675b8e57c61f70167ba415a463651fd6a918"
|
checksum = "b0293b4b29daaf487284529cc2f5675b8e57c61f70167ba415a463651fd6a918"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "seq-macro"
|
name = "seq-macro"
|
||||||
@ -6174,6 +6230,14 @@ dependencies = [
|
|||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "xtask"
|
||||||
|
version = "1.6.0"
|
||||||
|
dependencies = [
|
||||||
|
"cargo_metadata",
|
||||||
|
"clap",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "yada"
|
name = "yada"
|
||||||
version = "0.5.0"
|
version = "0.5.0"
|
||||||
|
@ -16,6 +16,7 @@ members = [
|
|||||||
"json-depth-checker",
|
"json-depth-checker",
|
||||||
"benchmarks",
|
"benchmarks",
|
||||||
"fuzzers",
|
"fuzzers",
|
||||||
|
"xtask",
|
||||||
]
|
]
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
|
@ -137,3 +137,6 @@ greek = ["charabia/greek"]
|
|||||||
|
|
||||||
# allow khmer specialized tokenization
|
# allow khmer specialized tokenization
|
||||||
khmer = ["charabia/khmer"]
|
khmer = ["charabia/khmer"]
|
||||||
|
|
||||||
|
# allow CUDA support, see <https://github.com/meilisearch/meilisearch/issues/4306>
|
||||||
|
cuda = ["candle-core/cuda"]
|
||||||
|
@ -70,7 +70,13 @@ impl std::fmt::Debug for Embedder {
|
|||||||
|
|
||||||
impl Embedder {
|
impl Embedder {
|
||||||
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
||||||
let device = candle_core::Device::Cpu;
|
let device = match candle_core::Device::cuda_if_available(0) {
|
||||||
|
Ok(device) => device,
|
||||||
|
Err(error) => {
|
||||||
|
log::warn!("could not initialize CUDA device for Hugging Face embedder, defaulting to CPU: {}", error);
|
||||||
|
candle_core::Device::Cpu
|
||||||
|
}
|
||||||
|
};
|
||||||
let repo = match options.revision.clone() {
|
let repo = match options.revision.clone() {
|
||||||
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
|
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
|
||||||
None => Repo::model(options.model.clone()),
|
None => Repo::model(options.model.clone()),
|
||||||
|
15
xtask/Cargo.toml
Normal file
15
xtask/Cargo.toml
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
[package]
|
||||||
|
name = "xtask"
|
||||||
|
version.workspace = true
|
||||||
|
authors.workspace = true
|
||||||
|
description = "Workspace automation tool following the xtask pattern <https://github.com/matklad/cargo-xtask>"
|
||||||
|
homepage.workspace = true
|
||||||
|
readme.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
cargo_metadata = "0.18.1"
|
||||||
|
clap = { version = "4.4.14", features = ["derive"] }
|
41
xtask/src/main.rs
Normal file
41
xtask/src/main.rs
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
/// List features available in the workspace
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
struct ListFeaturesDeriveArgs {
|
||||||
|
/// Feature to exclude from the list. Repeat the argument to exclude multiple features
|
||||||
|
#[arg(short, long)]
|
||||||
|
exclude_feature: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Utilitary commands
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about)]
|
||||||
|
#[command(name = "cargo xtask")]
|
||||||
|
#[command(bin_name = "cargo xtask")]
|
||||||
|
enum Command {
|
||||||
|
ListFeatures(ListFeaturesDeriveArgs),
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let args = Command::parse();
|
||||||
|
match args {
|
||||||
|
Command::ListFeatures(args) => list_features(args),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn list_features(args: ListFeaturesDeriveArgs) {
|
||||||
|
let exclude_features: HashSet<_> = args.exclude_feature.into_iter().collect();
|
||||||
|
let metadata = cargo_metadata::MetadataCommand::new().no_deps().exec().unwrap();
|
||||||
|
let features: Vec<String> = metadata
|
||||||
|
.packages
|
||||||
|
.iter()
|
||||||
|
.flat_map(|package| package.features.keys())
|
||||||
|
.filter(|feature| !exclude_features.contains(feature.as_str()))
|
||||||
|
.map(|s| s.to_owned())
|
||||||
|
.collect();
|
||||||
|
let features = features.join(" ");
|
||||||
|
println!("{features}")
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user