🚧 basic sqlite operations implementation

This commit is contained in:
Nate Sesti
2024-01-06 21:27:08 -08:00
parent 7d5782c184
commit 53b7f4b556
5 changed files with 249 additions and 42 deletions

50
sync/Cargo.lock generated
View File

@@ -279,6 +279,27 @@ dependencies = [
"subtle",
]
[[package]]
name = "dirs"
version = "5.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
dependencies = [
"dirs-sys",
]
[[package]]
name = "dirs-sys"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c"
dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.48.0",
]
[[package]]
name = "errno"
version = "0.3.8"
@@ -619,6 +640,17 @@ dependencies = [
"winapi",
]
[[package]]
name = "libredox"
version = "0.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8"
dependencies = [
"bitflags 2.4.1",
"libc",
"redox_syscall",
]
[[package]]
name = "libsqlite3-sys"
version = "0.27.0"
@@ -795,6 +827,12 @@ version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "option-ext"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "parking_lot"
version = "0.12.1"
@@ -1001,6 +1039,17 @@ dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "redox_users"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a18479200779601e498ada4e8c1e1f50e3ee19deb0259c25825a98b5603b2cb4"
dependencies = [
"getrandom",
"libredox",
"thiserror",
]
[[package]]
name = "regex-automata"
version = "0.4.3"
@@ -1257,6 +1306,7 @@ dependencies = [
name = "sync"
version = "0.1.0"
dependencies = [
"dirs",
"hex-literal",
"homedir",
"ignore",

View File

@@ -13,6 +13,7 @@ crate-type = ["cdylib"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
dirs = "5.0.1"
hex-literal = "0.4.1"
homedir = "0.2.1"
ignore = "0.4.20"

View File

@@ -1,8 +1,27 @@
use rand::Rng;
use rusqlite::{Connection, Result};
use ndarray::{Array1, Array2};
use rusqlite::Connection;
use std::fs;
#[derive(Debug)]
struct Chunk {
/// Returns the indices of the `top_n` rows of `vectors` whose dot product with `v` is largest,
/// ordered from the highest score to the lowest.
///
/// Only the first `d` components of `v` and of each row are read. Rows with equal scores keep
/// their original relative order because the sort is stable.
///
/// # Panics
///
/// Panics if `v` or any row of `vectors` has fewer than `d` elements.
fn get_top_n(v: Vec<f32>, vectors: Vec<Vec<f32>>, d: usize, top_n: usize) -> Vec<usize> {
    let n = vectors.len();
    let a: Array2<f32> = Array2::from_shape_fn((n, d), |(i, j)| vectors[i][j]);
    let b: Array1<f32> = Array1::from_shape_fn(d, |i| v[i]);
    let scores = a.dot(&b);

    // A NaN score cannot be ordered meaningfully; demote it to -inf so it ranks last instead
    // of panicking inside the comparator.
    let mut ranked: Vec<(usize, f32)> = scores
        .iter()
        .map(|&score| if score.is_nan() { f32::NEG_INFINITY } else { score })
        .enumerate()
        .collect();
    // Descending by score; `total_cmp` is a total order, so the sort can never panic.
    ranked.sort_by(|x, y| y.1.total_cmp(&x.1));

    ranked
        .into_iter()
        .take(top_n)
        .map(|(index, _score)| index)
        .collect()
}
#[derive(Debug, Clone)]
pub struct Chunk {
id: i32,
content: String,
embedding: String,
@@ -23,17 +42,161 @@ pub fn text_to_embedding(text: String) -> Vec<f32> {
.collect::<Vec<f32>>();
}
fn rand_embedding(n: i32) -> Vec<f32> {
let mut rng = rand::thread_rng();
(0..n).map(|_| rng.gen()).collect()
// Database interactions
/// Opens a connection to the SQLite file `sync.db` located in `<home>/.continue/index`,
/// creating that directory (and any missing parents) beforehand.
///
/// # Panics
///
/// Panics if the home directory cannot be determined, the directory cannot be created, or
/// the database cannot be opened.
fn get_conn() -> Connection {
    let index_dir = dirs::home_dir().unwrap().join(".continue").join("index");
    fs::create_dir_all(&index_dir).unwrap();
    Connection::open(index_dir.join("sync.db")).unwrap()
}
/// Creates the `chunks` and `tags` tables if they do not exist yet.
///
/// Both statements use `IF NOT EXISTS`, so calling this repeatedly is harmless.
///
/// # Panics
///
/// Panics if the connection cannot be obtained or either statement fails.
pub fn create_database() {
    // Executed in order: chunk storage first, then the chunk-to-tag mapping.
    let schema = [
        "CREATE TABLE IF NOT EXISTS chunks (
id TEXT PRIMARY KEY,
content TEXT NOT NULL,
embedding TEXT NOT NULL
)",
        "CREATE TABLE IF NOT EXISTS tags (
id INTEGER PRIMARY KEY,
chunk_id TEXT NOT NULL,
tag TEXT NOT NULL
)",
    ];

    let db = get_conn();
    for ddl in schema {
        db.execute(ddl, ()).unwrap();
    }
}
/// Stores a chunk and its tags.
///
/// Inserts one row into `chunks` (keyed by `hash`, with the embedding serialized through
/// `embedding_to_text`) and one row into `tags` per entry of `tags`. All inserts run inside a
/// single transaction, so a failure part-way through leaves the database unchanged instead of
/// a chunk with only some of its tags.
///
/// # Panics
///
/// Panics if the database cannot be opened or any statement fails. Since `chunks.id` is a
/// primary key, inserting a `hash` that is already present is expected to fail.
pub fn add_chunk(hash: String, content: String, tags: Vec<String>, embedding: Vec<f32>) {
    let mut conn = get_conn();
    // Dropping the transaction without committing (e.g. on panic) rolls it back.
    let tx = conn.transaction().unwrap();

    tx.execute(
        "INSERT INTO chunks (id, content, embedding) VALUES (?1, ?2, ?3)",
        (&hash, &content, &embedding_to_text(embedding)),
    )
    .unwrap();

    {
        // Prepared once and reused for every tag; scoped so the statement is dropped
        // before the transaction is consumed by `commit`.
        let mut insert_tag = tx
            .prepare("INSERT INTO tags (chunk_id, tag) VALUES (?1, ?2)")
            .unwrap();
        for tag in &tags {
            insert_tag.execute((&hash, tag)).unwrap();
        }
    }

    tx.commit().unwrap();
}
/// Deletes the chunk identified by `hash` together with every tag row that references it.
///
/// Both deletes run inside one transaction so the chunk can never be removed while its tag
/// rows are left behind (or vice versa). Deleting a `hash` that is not present is a no-op.
///
/// # Panics
///
/// Panics if the database cannot be opened or any statement fails.
pub fn remove_chunk(hash: String) {
    let mut conn = get_conn();
    // Dropping the transaction without committing (e.g. on panic) rolls it back.
    let tx = conn.transaction().unwrap();
    tx.execute("DELETE FROM chunks WHERE id=?1", (&hash,))
        .unwrap();
    tx.execute("DELETE FROM tags WHERE chunk_id=?1", (&hash,))
        .unwrap();
    tx.commit().unwrap();
}
/// Attaches `tag` to the chunk identified by `hash` by inserting one row into `tags`.
///
/// No check is made that a chunk with this `hash` exists or that the pair is not already
/// present.
///
/// # Panics
///
/// Panics if the database cannot be opened or the insert fails.
pub fn add_tag(hash: String, tag: String) {
    let sql = "INSERT INTO tags (chunk_id, tag) VALUES (?1, ?2)";
    let db = get_conn();
    db.execute(sql, (hash.as_str(), tag.as_str())).unwrap();
}
/// Detaches `tag` from the chunk identified by `hash` by deleting every matching row
/// from `tags`.
///
/// # Panics
///
/// Panics if the database cannot be opened or the delete fails.
pub fn remove_tag(hash: String, tag: String) {
    let sql = "DELETE FROM tags WHERE chunk_id=?1 AND tag=?2";
    let db = get_conn();
    db.execute(sql, (hash.as_str(), tag.as_str())).unwrap();
}
pub fn retrieve(n: usize, tags: Vec<String>, v: Vec<f32>) -> Vec<Chunk> {
let conn = get_conn();
let mut stmt = conn
.prepare(&format!(
"
SELECT * FROM chunks
WHERE id IN (
SELECT chunk_id
FROM tags
WHERE tag IN (?1)
)",
))
.unwrap();
let chunk_rows = stmt
.query_map((tags.join(", "),), |row| {
Ok(Chunk {
id: row.get(0)?,
content: row.get(1)?,
embedding: row.get(2)?,
})
})
.unwrap();
let mut chunks = Vec::new();
let mut vectors: Vec<Vec<f32>> = Vec::new();
for chunk in chunk_rows {
let chunk = chunk.unwrap();
chunks.push(chunk.clone());
let vector = text_to_embedding(chunk.embedding);
vectors.push(vector);
}
let top_n_indices = get_top_n(v, vectors, 384, n);
return chunks
.iter()
.cloned()
.enumerate()
.filter(|(index, _chunk)| top_n_indices.contains(index))
.map(|(_index, chunk)| chunk)
.collect::<Vec<Chunk>>();
}
#[cfg(test)]
mod tests {
use super::*;
use rusqlite::{Connection, Result};
use ndarray::{Array1, Array2};
use rand::Rng;
use rusqlite::Connection;
use std::time::Instant;
/// Builds a vector of `n` random `f32` values drawn from the thread-local RNG.
///
/// A non-positive `n` yields an empty vector.
fn rand_embedding(n: i32) -> Vec<f32> {
    let mut rng = rand::thread_rng();
    let mut values = Vec::new();
    for _ in 0..n {
        values.push(rng.gen());
    }
    values
}
/// Verifies that `create_database` leaves both the `chunks` and `tags` tables in place.
#[test]
fn test_create_database() {
    create_database();

    let conn = get_conn();
    let mut stmt = conn
        .prepare("SELECT name FROM sqlite_master WHERE type='table'")
        .unwrap();
    // Materialize the names first: the row iterator is single-pass, so running two `any`
    // checks on it would make the second one depend on the order of the rows.
    let table_names: Vec<String> = stmt
        .query_map([], |row| row.get::<usize, String>(0))
        .unwrap()
        .map(|name| name.unwrap())
        .collect();

    for expected in ["chunks", "tags"] {
        assert!(
            table_names.iter().any(|name| name == expected),
            "missing table `{}`",
            expected
        );
    }
}
#[test]
fn benchmark_load_vectors() {
let conn = Connection::open("sync.db").unwrap();
@@ -50,7 +213,7 @@ mod tests {
let time = Instant::now();
for i in 0..10_000 {
for _ in 0..10_000 {
let chunk = Chunk {
id: 0,
content: "Test content".to_string(),
@@ -84,10 +247,35 @@ mod tests {
let mut i = 0;
for chunk in chunk_iter {
i += 1;
let embedding = text_to_embedding(chunk.unwrap().embedding);
let _ = text_to_embedding(chunk.unwrap().embedding);
}
println!("Found {} chunks", i);
println!("To convert took: {:.2?}", time.elapsed());
}
/// Times a 10_000 x 384 matrix-vector product followed by a full descending sort of the
/// scores and selection of the 50 best indices, then prints the elapsed time.
#[test]
fn benchmark_ndarray() {
    let n = 10_000;
    let d = 384;
    let top_n = 50;

    // Random inputs are generated outside the timed region.
    let mut rng = rand::thread_rng();
    let a: Array2<f32> = Array2::from_shape_fn((n, d), |_| rng.gen::<f32>());
    let b: Array1<f32> = Array1::from_shape_fn(d, |_| rng.gen::<f32>());

    let time = Instant::now();

    let result = a.dot(&b);
    let mut scored: Vec<(usize, &f32)> = result.iter().enumerate().collect();
    scored.sort_by(|lhs, rhs| rhs.1.partial_cmp(lhs.1).unwrap());
    let _: Vec<usize> = scored
        .into_iter()
        .take(top_n)
        .map(|(index, _value)| index)
        .collect();

    let elapsed = time.elapsed();
    println!("Elapsed time: {:.2?}", elapsed);
}
}

View File

@@ -1,7 +1,6 @@
use std::path::Path;
mod db;
mod gitignore;
mod similarity;
mod sync;
mod utils;

View File

@@ -1,31 +0,0 @@
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2};
    use rand::Rng;
    use std::time::Instant;

    /// Times a 10_000 x 384 matrix-vector product plus a descending sort of the scores and
    /// selection of the 50 best indices, then prints the elapsed time.
    #[test]
    fn test_ndarray() {
        let mut rng = rand::thread_rng();
        let n = 10_000;
        let d = 384;
        let a: Array2<f32> = Array2::from_shape_fn((n, d), |_| rng.gen::<f32>());
        let b: Array1<f32> = Array1::from_shape_fn(d, |_| rng.gen::<f32>());

        let time = Instant::now();
        let result = a.dot(&b);
        let mut indexed_result: Vec<(usize, &f32)> = result.iter().enumerate().collect();
        indexed_result.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
        let top_n = 50;
        // Only computed for timing; the underscore prefix silences the unused-variable warning
        // while keeping the value alive until the clock is read.
        let _top_n_indices: Vec<usize> = indexed_result
            .into_iter()
            .map(|(index, _value)| index)
            .take(top_n)
            .collect();
        let elapsed = time.elapsed();
        println!("Elapsed time: {:.2?}", elapsed);
    }
}