🚧 basic sqlite operations implementation
This commit is contained in:
50
sync/Cargo.lock
generated
50
sync/Cargo.lock
generated
@@ -279,6 +279,27 @@ dependencies = [
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs"
|
||||
version = "5.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
|
||||
dependencies = [
|
||||
"dirs-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs-sys"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"option-ext",
|
||||
"redox_users",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.8"
|
||||
@@ -619,6 +640,17 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libredox"
|
||||
version = "0.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libsqlite3-sys"
|
||||
version = "0.27.0"
|
||||
@@ -795,6 +827,12 @@ version = "1.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
|
||||
|
||||
[[package]]
|
||||
name = "option-ext"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.1"
|
||||
@@ -1001,6 +1039,17 @@ dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_users"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a18479200779601e498ada4e8c1e1f50e3ee19deb0259c25825a98b5603b2cb4"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"libredox",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.4.3"
|
||||
@@ -1257,6 +1306,7 @@ dependencies = [
|
||||
name = "sync"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"dirs",
|
||||
"hex-literal",
|
||||
"homedir",
|
||||
"ignore",
|
||||
|
||||
@@ -13,6 +13,7 @@ crate-type = ["cdylib"]
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
dirs = "5.0.1"
|
||||
hex-literal = "0.4.1"
|
||||
homedir = "0.2.1"
|
||||
ignore = "0.4.20"
|
||||
|
||||
@@ -1,8 +1,27 @@
|
||||
use rand::Rng;
|
||||
use rusqlite::{Connection, Result};
|
||||
use ndarray::{Array1, Array2};
|
||||
use rusqlite::Connection;
|
||||
use std::fs;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Chunk {
|
||||
fn get_top_n(v: Vec<f32>, vectors: Vec<Vec<f32>>, d: usize, top_n: usize) -> Vec<usize> {
|
||||
let n = vectors.len();
|
||||
let a: Array2<f32> = Array2::from_shape_fn((n, d), |(i, j)| vectors[i][j]);
|
||||
let b: Array1<f32> = Array1::from_shape_fn(d, |i| v[i]);
|
||||
|
||||
let result = a.dot(&b);
|
||||
let mut indexed_result: Vec<(usize, &f32)> = result.iter().enumerate().collect();
|
||||
indexed_result.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
|
||||
|
||||
let top_n_indices: Vec<usize> = indexed_result
|
||||
.into_iter()
|
||||
.map(|(index, _value)| index)
|
||||
.take(top_n)
|
||||
.collect();
|
||||
|
||||
return top_n_indices;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Chunk {
|
||||
id: i32,
|
||||
content: String,
|
||||
embedding: String,
|
||||
@@ -23,17 +42,161 @@ pub fn text_to_embedding(text: String) -> Vec<f32> {
|
||||
.collect::<Vec<f32>>();
|
||||
}
|
||||
|
||||
fn rand_embedding(n: i32) -> Vec<f32> {
|
||||
let mut rng = rand::thread_rng();
|
||||
(0..n).map(|_| rng.gen()).collect()
|
||||
// Database interactions
|
||||
fn get_conn() -> Connection {
|
||||
let path = dirs::home_dir()
|
||||
.unwrap()
|
||||
.join(".continue")
|
||||
.join("index")
|
||||
.join("sync.db");
|
||||
fs::create_dir_all(path.parent().unwrap()).unwrap();
|
||||
return Connection::open(path).unwrap();
|
||||
}
|
||||
|
||||
pub fn create_database() {
|
||||
let conn = get_conn();
|
||||
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS chunks (
|
||||
id TEXT PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
embedding TEXT NOT NULL
|
||||
)",
|
||||
(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS tags (
|
||||
id INTEGER PRIMARY KEY,
|
||||
chunk_id TEXT NOT NULL,
|
||||
tag TEXT NOT NULL
|
||||
)",
|
||||
(),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn add_chunk(hash: String, content: String, tags: Vec<String>, embedding: Vec<f32>) {
|
||||
let conn = get_conn();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO chunks (id, content, embedding) VALUES (?1, ?2, ?3)",
|
||||
(&hash, &content, &embedding_to_text(embedding)),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
for tag in tags {
|
||||
conn.execute(
|
||||
"INSERT INTO tags (chunk_id, tag) VALUES (?1, ?2)",
|
||||
(&hash, &tag),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_chunk(hash: String) {
|
||||
let conn = get_conn();
|
||||
|
||||
conn.execute("DELETE FROM chunks WHERE id=?1", (&hash,))
|
||||
.unwrap();
|
||||
|
||||
conn.execute("DELETE FROM tags WHERE chunk_id=?1", (&hash,))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn add_tag(hash: String, tag: String) {
|
||||
let conn = get_conn();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO tags (chunk_id, tag) VALUES (?1, ?2)",
|
||||
(&hash, &tag),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn remove_tag(hash: String, tag: String) {
|
||||
let conn = get_conn();
|
||||
|
||||
conn.execute(
|
||||
"DELETE FROM tags WHERE chunk_id=?1 AND tag=?2",
|
||||
(&hash, &tag),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
pub fn retrieve(n: usize, tags: Vec<String>, v: Vec<f32>) -> Vec<Chunk> {
|
||||
let conn = get_conn();
|
||||
|
||||
let mut stmt = conn
|
||||
.prepare(&format!(
|
||||
"
|
||||
SELECT * FROM chunks
|
||||
WHERE id IN (
|
||||
SELECT chunk_id
|
||||
FROM tags
|
||||
WHERE tag IN (?1)
|
||||
)",
|
||||
))
|
||||
.unwrap();
|
||||
let chunk_rows = stmt
|
||||
.query_map((tags.join(", "),), |row| {
|
||||
Ok(Chunk {
|
||||
id: row.get(0)?,
|
||||
content: row.get(1)?,
|
||||
embedding: row.get(2)?,
|
||||
})
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
let mut vectors: Vec<Vec<f32>> = Vec::new();
|
||||
for chunk in chunk_rows {
|
||||
let chunk = chunk.unwrap();
|
||||
chunks.push(chunk.clone());
|
||||
let vector = text_to_embedding(chunk.embedding);
|
||||
vectors.push(vector);
|
||||
}
|
||||
|
||||
let top_n_indices = get_top_n(v, vectors, 384, n);
|
||||
return chunks
|
||||
.iter()
|
||||
.cloned()
|
||||
.enumerate()
|
||||
.filter(|(index, _chunk)| top_n_indices.contains(index))
|
||||
.map(|(_index, chunk)| chunk)
|
||||
.collect::<Vec<Chunk>>();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rusqlite::{Connection, Result};
|
||||
use ndarray::{Array1, Array2};
|
||||
use rand::Rng;
|
||||
use rusqlite::Connection;
|
||||
use std::time::Instant;
|
||||
|
||||
fn rand_embedding(n: i32) -> Vec<f32> {
|
||||
let mut rng = rand::thread_rng();
|
||||
(0..n).map(|_| rng.gen()).collect()
|
||||
}
|
||||
|
||||
/// Verifies that `create_database` leaves both the `chunks` and `tags` tables in place.
#[test]
fn test_create_database() {
    create_database();

    let conn = get_conn();
    let mut stmt = conn
        .prepare("SELECT name FROM sqlite_master WHERE type='table'")
        .unwrap();

    // Collect the names first: calling `any` twice on the row iterator itself would
    // make the second check see only the rows left over after the first match, so
    // the test would depend on the order in which SQLite lists the tables.
    let table_names: Vec<String> = stmt
        .query_map([], |row| row.get::<usize, String>(0))
        .unwrap()
        .map(|name| name.unwrap())
        .collect();

    assert!(table_names.iter().any(|name| name == "chunks"));
    assert!(table_names.iter().any(|name| name == "tags"));
}
|
||||
|
||||
#[test]
|
||||
fn benchmark_load_vectors() {
|
||||
let conn = Connection::open("sync.db").unwrap();
|
||||
@@ -50,7 +213,7 @@ mod tests {
|
||||
|
||||
let time = Instant::now();
|
||||
|
||||
for i in 0..10_000 {
|
||||
for _ in 0..10_000 {
|
||||
let chunk = Chunk {
|
||||
id: 0,
|
||||
content: "Test content".to_string(),
|
||||
@@ -84,10 +247,35 @@ mod tests {
|
||||
let mut i = 0;
|
||||
for chunk in chunk_iter {
|
||||
i += 1;
|
||||
let embedding = text_to_embedding(chunk.unwrap().embedding);
|
||||
let _ = text_to_embedding(chunk.unwrap().embedding);
|
||||
}
|
||||
|
||||
println!("Found {} chunks", i);
|
||||
println!("To convert took: {:.2?}", time.elapsed());
|
||||
}
|
||||
|
||||
/// Times a dot product of a random 10 000 x 384 matrix with a random vector,
/// followed by a descending sort and selection of the 50 best indices, and
/// prints the elapsed time.
#[test]
fn benchmark_ndarray() {
    let n = 10_000;
    let d = 384;
    let top_n = 50;

    let mut rng = rand::thread_rng();
    let a: Array2<f32> = Array2::from_shape_fn((n, d), |_| rng.gen::<f32>());
    let b: Array1<f32> = Array1::from_shape_fn(d, |_| rng.gen::<f32>());

    // Only the product, sort and selection are measured; the setup above is not.
    let time = Instant::now();

    let scores = a.dot(&b);
    let mut ranked: Vec<(usize, &f32)> = scores.iter().enumerate().collect();
    ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(lhs.1).unwrap());

    let _: Vec<usize> = ranked
        .into_iter()
        .map(|(index, _score)| index)
        .take(top_n)
        .collect();

    let elapsed = time.elapsed();
    println!("Elapsed time: {:.2?}", elapsed);
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use std::path::Path;
|
||||
mod db;
|
||||
mod gitignore;
|
||||
mod similarity;
|
||||
mod sync;
|
||||
mod utils;
|
||||
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2};
    use rand::Rng;
    use std::time::Instant;

    /// Times a dot product of a random 10 000 x 384 matrix with a random vector,
    /// followed by a descending sort and selection of the 50 best indices, and
    /// prints the elapsed time.
    #[test]
    fn test_ndarray() {
        let n = 10_000;
        let d = 384;
        let top_n = 50;

        let mut rng = rand::thread_rng();
        let a: Array2<f32> = Array2::from_shape_fn((n, d), |_| rng.gen::<f32>());
        let b: Array1<f32> = Array1::from_shape_fn(d, |_| rng.gen::<f32>());

        // Only the product, sort and selection are measured; the setup above is not.
        let time = Instant::now();

        let scores = a.dot(&b);
        let mut ranked: Vec<(usize, &f32)> = scores.iter().enumerate().collect();
        ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(lhs.1).unwrap());

        let _top_n_indices: Vec<usize> = ranked
            .into_iter()
            .map(|(index, _score)| index)
            .take(top_n)
            .collect();

        let elapsed = time.elapsed();
        println!("Elapsed time: {:.2?}", elapsed);
    }
}
|
||||
Reference in New Issue
Block a user