diff --git a/scripts/ai/main.py b/scripts/ai/main.py new file mode 100644 index 0000000000..4e1d443574 --- /dev/null +++ b/scripts/ai/main.py @@ -0,0 +1,17 @@ +from typing import Optional, List + +from fastapi import FastAPI +from pydantic import BaseModel +from sentence_transformers import SentenceTransformer + +class Texts(BaseModel): + texts: List + +app = FastAPI() + +model = SentenceTransformer("all-MiniLM-L6-v2") + +@app.post("/embedding/") +async def embedding(texts: Texts): + data = model.encode(list(texts), convert_to_numpy=True).tolist() + return {"embedding": data} diff --git a/src/main/frontend/modules/ai/database.cljs b/src/main/frontend/modules/ai/database.cljs new file mode 100644 index 0000000000..6342658aca --- /dev/null +++ b/src/main/frontend/modules/ai/database.cljs @@ -0,0 +1,76 @@ +(ns frontend.modules.ai.database + (:require [frontend.util :as util] + [frontend.db :as db] + [frontend.db.model :as db-model] + [frontend.modules.ai.embedding.local :as embedding] + [cljs-bean.core :as bean] + [promesa.core :as p])) + +(def api "http://localhost:6333/") + +(defn- fetch + [uri opts] + (util/fetch uri opts p/resolved p/rejected)) + +(defn create-collection! + [graph-id] + ;; FIXME: generate uuid for each graph if not exists + (let [graph-id (db/get-short-repo-name (frontend.state/get-current-repo))] + (fetch (str api "collections/" graph-id) + {:method "PUT" + :headers {:Content-Type "application/json"} + :body (js/JSON.stringify + (bean/->js {:vectors {:size 384 ; all-MiniLM-L6-v2 + :distance "Dot"}}))}))) + +(defn get-collection + [graph-id] + (fetch (str api "collections/" graph-id) + {:method "GET" + :headers {:Content-Type "application/json"}})) + +(defn- get-blocks + [] + (->> (db-model/get-all-block-contents) + (map #(select-keys % [:block/uuid :block/content])))) + +(defn- blocks->points + [blocks] + (p/all + (map (fn [block] + (p/let [content (:block/content block) + result (embedding/sentence-transformer content)] + {:id (str (:block/uuid block)) + :vector result + :payload block})) blocks))) + +(defn add-points! + [graph-id] + (let [blocks (partition-all 100 (get-blocks))] + (doseq [segment blocks] + (p/let [points (blocks->points segment)] + (fetch (str api "collections/" graph-id "/points?wait=true") + {:method "PUT" + :headers {:Content-Type "application/json"} + :body (js/JSON.stringify + (bean/->js {:points points}))}))))) + +(defn get-top-k + [graph-id q {:keys [top] + :or {top 5}}] + (p/let [vector (embedding/sentence-transformer q)] + (fetch (str api "collections/" graph-id "/points/search") + {:method "POST" + :headers {:Content-Type "application/json"} + :body (js/JSON.stringify + (bean/->js {:vector vector + :top top}))}))) + + +(comment + (p/let [result (get-top-k "docs" "new to logseq" {})] + (doseq [{:keys [id]} (:result result)] + (let [block (-> (db/pull [:block/uuid (uuid id)]) + (select-keys [:block/content :block/page :block/uuid]))] + (prn block)))) + ) diff --git a/src/main/frontend/modules/ai/embedding/local.cljs b/src/main/frontend/modules/ai/embedding/local.cljs new file mode 100644 index 0000000000..efca5eb898 --- /dev/null +++ b/src/main/frontend/modules/ai/embedding/local.cljs @@ -0,0 +1,20 @@ +(ns frontend.modules.ai.embedding.local + (:require [frontend.util :as util] + [cljs-bean.core :as bean] + [promesa.core :as p])) + +;; TODO: only for playground, this should be supported by plugins +(defn sentence-transformer + [text] + (util/fetch "http://127.0.0.1:8000/embedding/" + {:method "POST" + :headers {:Content-Type "application/json"} + :body (js/JSON.stringify + (bean/->js {:texts [text]}))} + (fn [result] + (p/resolved (first (:embedding result)))) + (fn [failed-resp] + (prn "sentence-transformer embedding failed: " + {:text text + :failed failed-resp}) + (p/rejected failed-resp))))