ai-code-assistant/pkg/indexer/indexer.go

package indexer

import (
    "ai-code-assistant/pkg/database"
    "ai-code-assistant/pkg/llm"
    "context"
    "github.com/tmc/langchaingo/schema"
    "log/slog"
    "os"
    "path/filepath"
    "strconv"
)

// Indexer is responsible for crawling the HEAD of a Git repository and generating embeddings so that the most
// relevant chunks of code can be identified for a given prompt.
type Indexer struct {
    repoPath          string
    chunkSize         int
    force             bool
    db                *database.Database
    llm               *llm.LLM
    allowedExtensions []string
}

// New creates a new Indexer.
func New(ctx context.Context, path string, chunkSize int, force bool) *Indexer {
    return &Indexer{
        repoPath:  path,
        chunkSize: chunkSize,
        force:     force,
        db:        database.FromContext(ctx),
        llm:       llm.FromContext(ctx),
        allowedExtensions: []string{
            ".go",
        },
    }
}
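
// A minimal usage sketch (hypothetical, not part of this file): it assumes a context that has already been populated
// with the handles that database.FromContext and llm.FromContext expect.
//
//	idx := indexer.New(ctx, "/path/to/repo", 1024, false)
//	if err := idx.Index(ctx); err != nil {
//	    slog.Error("indexing failed", "error", err)
//	}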

// Index crawls the repository looking for supported files and indexes them.
// Files are indexed by the path to the repository and the Git commit hash of the current HEAD.
func (idx *Indexer) Index(ctx context.Context) error {
    repoID, didUpdate, err := idx.db.UpsertRepo(ctx, idx.repoPath)
    if err != nil {
        return err
    }
    if didUpdate && !idx.force {
        slog.Info("repo already indexed, skipping")
        return nil
    } else if didUpdate && idx.force {
        slog.Info("repo already indexed, but force flag was set, cleaning and re-indexing")
        if err := idx.db.ClearChunkIndex(ctx, repoID); err != nil {
            return err
        }
    } else {
        slog.Info("indexing new repository", "path", idx.repoPath, "repo_id", repoID)
    }
    return idx.generateFileChunks(ctx, repoID)
}
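
// headCommit is a hypothetical helper, shown only to illustrate where the HEAD hash mentioned in Index's doc comment
// could come from; the real lookup presumably lives behind db.UpsertRepo. It shells out to git and would need
// "os/exec" and "strings" added to the imports.
func headCommit(ctx context.Context, repoPath string) (string, error) {
    // git -C <repo> rev-parse HEAD prints the commit hash of the current HEAD.
    out, err := exec.CommandContext(ctx, "git", "-C", repoPath, "rev-parse", "HEAD").Output()
    if err != nil {
        return "", err
    }
    return strings.TrimSpace(string(out)), nil
}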

// generateFileChunks crawls the repository looking for supported files and stores their chunk embeddings in the
// database so we can look up relevant chunks later.
func (idx *Indexer) generateFileChunks(ctx context.Context, repoID string) error {
    vectorStore, closeFunc, err := idx.db.GetVectorStore(ctx, idx.llm.Embedder())
    if err != nil {
        return err
    }
    defer closeFunc()
    return crawlFiles(ctx, idx.repoPath, func(ctx context.Context, filePath string) error {
        // Check the allow list up front so unsupported files are never read or chunked.
        shouldIndex := false
        for _, ext := range idx.allowedExtensions {
            if filepath.Ext(filePath) == ext {
                shouldIndex = true
                break
            }
        }
        if !shouldIndex {
            return nil
        }
        chunkID := 0
        return chunkFile(ctx, filePath, idx.chunkSize, func(chunk []byte, start, end uint64) error {
            slog.Info("indexing file chunk", "chunk_id", chunkID, "chunk_size", len(chunk), "file_name", filePath)
            docs := []schema.Document{{
                PageContent: string(chunk),
                Metadata: map[string]any{
                    "type":      "file_chunk",
                    "file_path": filePath,
                    "chunk_id":  strconv.Itoa(chunkID),
                    "start":     strconv.FormatUint(start, 10),
                    "end":       strconv.FormatUint(end, 10),
                    "repo_id":   repoID,
                },
            }}
            if _, err := vectorStore.AddDocuments(ctx, docs); err != nil {
                return err
            }
            chunkID++
            return nil
        })
    })
}
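
// retrieveRelevantChunks is a hypothetical sketch of the lookup side of this index, not part of the file's real API:
// it asks the vector store for the k chunks closest to a prompt, using langchaingo's VectorStore.SimilaritySearch and
// the same GetVectorStore helper used above.
func (idx *Indexer) retrieveRelevantChunks(ctx context.Context, prompt string, k int) ([]schema.Document, error) {
    vectorStore, closeFunc, err := idx.db.GetVectorStore(ctx, idx.llm.Embedder())
    if err != nil {
        return nil, err
    }
    defer closeFunc()
    // The store embeds the prompt with the same embedder used at index time and returns the nearest documents.
    return vectorStore.SimilaritySearch(ctx, prompt, k)
}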

// crawlFiles recursively crawls the repository tree looking for files; when a file is located, the callback is
// called. Note that nothing is skipped: hidden directories such as .git are crawled too.
func crawlFiles(ctx context.Context, path string, cb func(ctx context.Context, filePath string) error) error {
    pathFiles, err := os.ReadDir(path)
    if err != nil {
        return err
    }
    for _, file := range pathFiles {
        filePath := filepath.Join(path, file.Name())
        if file.IsDir() {
            if err := crawlFiles(ctx, filePath, cb); err != nil {
                return err
            }
        } else {
            if err := cb(ctx, filePath); err != nil {
                return err
            }
        }
    }
    return nil
}

// chunkFile takes a file and returns it in chunks of a size suitable for embedding.
// This is a very simple algorithm right now; it would be better to use a lexer to identify good parts of the AST to
// split on (e.g. functions, or groups of functions). We could also build a reference graph to find the most relevant
// files based on the relationships between files.
func chunkFile(_ context.Context, filePath string, maxBytes int, chunkCb func(chunk []byte, start, end uint64) error) error {
    fileBytes, err := os.ReadFile(filePath)
    if err != nil {
        return err
    }
    pos := 0
    for pos < len(fileBytes) {
        nextChunkSize := maxBytes
        if pos+maxBytes > len(fileBytes) {
            nextChunkSize = len(fileBytes) - pos
        }
        if err := chunkCb(fileBytes[pos:pos+nextChunkSize], uint64(pos), uint64(pos+nextChunkSize)); err != nil {
            return err
        }
        pos += maxBytes
    }
    return nil
}
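
// chunkFileByDecls is a minimal sketch of the AST-aware splitting suggested in the comment on chunkFile above. It is
// hypothetical and not wired into the Indexer, and it would need "go/parser" and "go/token" added to the imports.
// Instead of fixed byte windows, it emits one chunk per top-level declaration (function, type, or var/const group),
// so chunk boundaries never cut a function in half. A real version would still need to merge tiny declarations and
// cap chunks at a maximum byte size.
func chunkFileByDecls(filePath string, chunkCb func(chunk []byte, start, end uint64) error) error {
    fileBytes, err := os.ReadFile(filePath)
    if err != nil {
        return err
    }
    fset := token.NewFileSet()
    parsed, err := parser.ParseFile(fset, filePath, fileBytes, parser.ParseComments)
    if err != nil {
        return err
    }
    for _, decl := range parsed.Decls {
        // Convert token positions back to byte offsets within the file.
        start := fset.Position(decl.Pos()).Offset
        end := fset.Position(decl.End()).Offset
        if err := chunkCb(fileBytes[start:end], uint64(start), uint64(end)); err != nil {
            return err
        }
    }
    return nil
}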