Start of Readme, More Cleanup

This commit is contained in:
2025-04-20 10:57:06 -04:00
parent 25f8cae8cb
commit 9dcd31dd04
8 changed files with 222 additions and 298 deletions

View File

@@ -17,6 +17,9 @@ import (
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/lib/pq"
"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/vectorstores"
"github.com/tmc/langchaingo/vectorstores/pgvector"
"path/filepath"
"strconv"
)
@@ -39,11 +42,13 @@ func preparedStatements() map[string]string {
}
}
// Database is a helper that abstracts access to the database as well as handles migrations.
type Database struct {
pool *pgxpool.Pool
}
func (db *Database) DB(ctx context.Context) (*pgxpool.Conn, error) {
// db gets a references to the connection pool.
func (db *Database) db(ctx context.Context) (*pgxpool.Conn, error) {
conn, err := db.pool.Acquire(ctx)
if err != nil {
return nil, err
@@ -58,6 +63,7 @@ func (db *Database) DB(ctx context.Context) (*pgxpool.Conn, error) {
return conn, nil
}
// UpsertRepo will ensure we have a repository reference for this particular repository path and HEAD reference.
func (db *Database) UpsertRepo(ctx context.Context, repoPath string) (string, bool, error) {
gitPath := osfs.New(filepath.Join(repoPath, ".git"))
@@ -71,7 +77,7 @@ func (db *Database) UpsertRepo(ctx context.Context, repoPath string) (string, bo
return "", false, err
}
conn, err := db.DB(ctx)
conn, err := db.db(ctx)
if err != nil {
return "", false, err
}
@@ -94,8 +100,9 @@ func (db *Database) UpsertRepo(ctx context.Context, repoPath string) (string, bo
return id, false, nil
}
// GetChunk will get a specified chunk from the database.
func (db *Database) GetChunk(ctx context.Context, chunkID int, path, repoID string) (string, error) {
conn, err := db.DB(ctx)
conn, err := db.db(ctx)
if err != nil {
return "", err
}
@@ -112,6 +119,43 @@ func (db *Database) GetChunk(ctx context.Context, chunkID int, path, repoID stri
return chunk, nil
}
// ClearChunkIndex will clear all embeddings for a particular repo ID.
func (db *Database) ClearChunkIndex(ctx context.Context, repoID string) error {
conn, err := db.db(ctx)
if err != nil {
return err
}
defer conn.Release()
if _, err := conn.Exec(ctx, "clear_chunks_for_repo", repoID); err != nil {
return err
}
return nil
}
// GetVectorStore gets a vector store capable of storing and searching embeddings.
func (db *Database) GetVectorStore(ctx context.Context, embedder embeddings.Embedder) (vectorstores.VectorStore, func(), error) {
conn, err := db.db(ctx)
if err != nil {
return nil, nil, err
}
vectorStore, err := pgvector.New(ctx,
pgvector.WithConn(conn),
pgvector.WithEmbedder(embedder),
pgvector.WithCollectionName("file_chunks"),
)
if err != nil {
return nil, nil, err
}
return vectorStore, func() {
conn.Release()
}, nil
}
// GetChunkContext will get a specified chunk and surrounding context based on some amount of distance.
func (db *Database) GetChunkContext(ctx context.Context, chunkID, distance int, path, repoID string) (string, error) {
minChunk := chunkID - distance
if minChunk < 0 {
@@ -134,6 +178,7 @@ func (db *Database) GetChunkContext(ctx context.Context, chunkID, distance int,
return chunkContext, nil
}
// FromConfig generates a new Database from a passed in configuration. Used for bootstrapping.
func FromConfig(ctx context.Context, cfg *config.Configuration) (*Database, error) {
migFS, err := iofs.New(migrations, "migrations")
if err != nil {
@@ -157,10 +202,12 @@ func FromConfig(ctx context.Context, cfg *config.Configuration) (*Database, erro
return &Database{pool: pool}, nil
}
// FromContext will retrieve the Database from a Context where WrapContext was used to embed the database in the context.
func FromContext(ctx context.Context) *Database {
return ctx.Value(contextKeyDB).(*Database)
}
// WrapContext embeds a Database inside a Context so it can be passed to functions.
func WrapContext(ctx context.Context, cfg *Database) context.Context {
return context.WithValue(ctx, contextKeyDB, cfg)
}

View File

@@ -5,13 +5,14 @@ import (
"ai-code-assistant/pkg/llm"
"context"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/vectorstores/pgvector"
"log/slog"
"os"
"path/filepath"
"strconv"
)
// Indexer is responsible for crawling the head of a Git repository and generating embeddings so that the most relevant
// chunks of code can be identified based on a given prompt.
type Indexer struct {
repoPath string
chunkSize int
@@ -21,6 +22,7 @@ type Indexer struct {
allowedExtensions []string
}
// New creates a new Indexer.
func New(ctx context.Context, path string, chunkSize int, force bool) *Indexer {
return &Indexer{
repoPath: path,
@@ -34,6 +36,8 @@ func New(ctx context.Context, path string, chunkSize int, force bool) *Indexer {
}
}
// Index will crawl a repository looking for supported files to index and will then index them.
// The files are indexed by the path to the repository and the Git commit hash of the current HEAD.
func (idx *Indexer) Index(ctx context.Context) error {
repoID, didUpdate, err := idx.db.UpsertRepo(ctx, idx.repoPath)
if err != nil {
@@ -46,7 +50,7 @@ func (idx *Indexer) Index(ctx context.Context) error {
return nil
} else if didUpdate && idx.force {
slog.Info("repo already indexed, but force flag was set, cleaning and re-indexing")
if err := idx.cleanIndexForRepo(ctx, repoID); err != nil {
if err := idx.db.ClearChunkIndex(ctx, repoID); err != nil {
return err
}
} else if !didUpdate {
@@ -60,58 +64,14 @@ func (idx *Indexer) Index(ctx context.Context) error {
return nil
}
func (idx *Indexer) cleanIndexForRepo(ctx context.Context, repoID string) error {
conn, err := idx.db.DB(ctx)
if err != nil {
return err
}
defer conn.Release()
if _, err := conn.Exec(ctx, "clear_chunks_for_repo", repoID); err != nil {
return err
}
return nil
}
func crawlFiles(ctx context.Context, path string, cb func(ctx context.Context, filePath string) error) error {
pathFiles, err := os.ReadDir(path)
if err != nil {
return err
}
for _, file := range pathFiles {
filePath := filepath.Join(path, file.Name())
if file.IsDir() {
if err := crawlFiles(ctx, filePath, cb); err != nil {
return err
}
} else {
if err := cb(ctx, filePath); err != nil {
return err
}
}
}
return nil
}
// generateFileChunks crawls the repository looking for supported files and will store the embeddings for those files
// in the database so we can look for relevant chunks later.
func (idx *Indexer) generateFileChunks(ctx context.Context, repoID string) error {
conn, err := idx.db.DB(ctx)
if err != nil {
return err
}
defer conn.Release()
vectorStore, err := pgvector.New(ctx,
pgvector.WithConn(conn),
pgvector.WithEmbedder(idx.llm.Embedder()),
pgvector.WithCollectionName("file_chunks"),
)
vectorStore, closeFunc, err := idx.db.GetVectorStore(ctx, idx.llm.Embedder())
if err != nil {
return err
}
defer closeFunc()
return crawlFiles(ctx, idx.repoPath, func(ctx context.Context, filePath string) error {
chunkID := 0
@@ -154,10 +114,34 @@ func (idx *Indexer) generateFileChunks(ctx context.Context, repoID string) error
})
}
// crawlFiles recursively crawls the repository tree looking for files, when a file is located the callback is called.
func crawlFiles(ctx context.Context, path string, cb func(ctx context.Context, filePath string) error) error {
pathFiles, err := os.ReadDir(path)
if err != nil {
return err
}
for _, file := range pathFiles {
filePath := filepath.Join(path, file.Name())
if file.IsDir() {
if err := crawlFiles(ctx, filePath, cb); err != nil {
return err
}
} else {
if err := cb(ctx, filePath); err != nil {
return err
}
}
}
return nil
}
// chunkFile will take a file and return it in chunks that are suitable size to be embedded.
// This is a very simple algorithm right now, it would be better to use a lexer to identify good parts of the AST to
// split on. We could also implement a reference graph to find the most relevant files based on the relationships
// between files.
// split on (e.g. functions, or groups of functions). We could also implement a reference graph to find the most
// relevant files based on the relationships between files.
func chunkFile(_ context.Context, filePath string, maxBytes int, chunkCb func(chunk []byte, start, end uint64) error) error {
fileBytes, err := os.ReadFile(filePath)
if err != nil {

View File

@@ -6,7 +6,6 @@ import (
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/vectorstores"
"github.com/tmc/langchaingo/vectorstores/pgvector"
"strconv"
)
@@ -37,20 +36,11 @@ func NewGetRelevantDocs(db *database.Database, llm *LLM, repoID string, size int
}
func (rd *RelevantDocs) GetRelevantFileChunks(ctx context.Context, query string) ([]*FileChunkID, error) {
conn, err := rd.db.DB(ctx)
if err != nil {
return nil, err
}
defer conn.Release()
vectorStore, err := pgvector.New(ctx,
pgvector.WithConn(conn),
pgvector.WithEmbedder(rd.llm.Embedder()),
pgvector.WithCollectionName("file_chunks"),
)
vectorStore, closeFunc, err := rd.db.GetVectorStore(ctx, rd.llm.Embedder())
if err != nil {
return nil, err
}
defer closeFunc()
retr := vectorstores.ToRetriever(vectorStore, rd.size, vectorstores.WithFilters(map[string]any{"type": "file_chunk", "repo_id": rd.repoID}))
retr.CallbacksHandler = rd.CallbacksHandler