package indexer

import (
	"context"
	"go/parser"
	"go/token"
	"log/slog"
	"os"
	"path/filepath"
	"slices"
	"strconv"
	"strings"

	"ai-code-assistant/pkg/database"
	"ai-code-assistant/pkg/llm"

	"github.com/tmc/langchaingo/schema"
)

// Indexer is responsible for crawling the HEAD of a Git repository and generating embeddings so that the most relevant
// chunks of code can be identified for a given prompt.
type Indexer struct {
	repoPath          string
	chunkSize         int
	force             bool
	db                *database.Database
	llm               *llm.LLM
	allowedExtensions []string
}

// New creates a new Indexer.
func New(ctx context.Context, path string, chunkSize int, force bool) *Indexer {
	return &Indexer{
		repoPath:  path,
		chunkSize: chunkSize,
		force:     force,
		db:        database.FromContext(ctx),
		llm:       llm.FromContext(ctx),
		allowedExtensions: []string{
			".go",
		},
	}
}

// Index crawls the repository looking for supported files and indexes them.
// Files are indexed against the path to the repository and the Git commit hash of the current HEAD.
func (idx *Indexer) Index(ctx context.Context) error {
	repoID, didUpdate, err := idx.db.UpsertRepo(ctx, idx.repoPath)
	if err != nil {
		return err
	}

	switch {
	case didUpdate && !idx.force:
		slog.Info("repo already indexed, skipping")
		return nil
	case didUpdate && idx.force:
		slog.Info("repo already indexed, but force flag was set, cleaning and re-indexing")
		if err := idx.db.ClearChunkIndex(ctx, repoID); err != nil {
			return err
		}
	default:
		slog.Info("indexing new repository", "path", idx.repoPath, "repo_id", repoID)
	}

	return idx.generateFileChunks(ctx, repoID)
}

// generateFileChunks crawls the repository looking for supported files and stores embeddings for their chunks in the
// database so that relevant chunks can be retrieved later.
func (idx *Indexer) generateFileChunks(ctx context.Context, repoID string) error {
	vectorStore, closeFunc, err := idx.db.GetVectorStore(ctx, idx.llm.Embedder())
	if err != nil {
		return err
	}
	defer closeFunc()

	return crawlFiles(ctx, idx.repoPath, func(ctx context.Context, filePath string) error {
		// Filter on extension before reading the file so unsupported files are never read or chunked.
		if !slices.Contains(idx.allowedExtensions, filepath.Ext(filePath)) {
			return nil
		}

		chunkID := 0
		return chunkFile(ctx, filePath, idx.chunkSize, func(chunk []byte, start, end uint64) error {
			slog.Info("indexing file", "chunk_id", chunkID, "chunk_size", len(chunk), "file_name", filePath)

			docs := []schema.Document{{
				PageContent: string(chunk),
				Metadata: map[string]any{
					"type":      "file_chunk",
					"file_path": filePath,
					"chunk_id":  strconv.Itoa(chunkID),
					"start":     strconv.FormatUint(start, 10),
					"end":       strconv.FormatUint(end, 10),
					"repo_id":   repoID,
				},
			}}

			if _, err := vectorStore.AddDocuments(ctx, docs); err != nil {
				return err
			}

			chunkID++
			return nil
		})
	})
}

// crawlFiles recursively walks the repository tree; the callback is called for each file found.
func crawlFiles(ctx context.Context, path string, cb func(ctx context.Context, filePath string) error) error {
	pathFiles, err := os.ReadDir(path)
	if err != nil {
		return err
	}

	for _, file := range pathFiles {
		filePath := filepath.Join(path, file.Name())
		if file.IsDir() {
			if err := crawlFiles(ctx, filePath, cb); err != nil {
				return err
			}
		} else {
			if err := cb(ctx, filePath); err != nil {
				return err
			}
		}
	}

	return nil
}
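// Because the Indexer targets Git repositories, crawlFiles as written also descends into the .git directory and hands
// VCS internals to the callback. The function below is a minimal sketch of a variant that skips dot-directories;
// crawlFilesSkippingVCS is an illustrative name, it is not called by the Indexer above, and it relies on the "strings"
// import in the block at the top of this file.
func crawlFilesSkippingVCS(ctx context.Context, path string, cb func(ctx context.Context, filePath string) error) error {
	pathFiles, err := os.ReadDir(path)
	if err != nil {
		return err
	}

	for _, file := range pathFiles {
		// Skip hidden entries such as .git so VCS internals are never read.
		if strings.HasPrefix(file.Name(), ".") {
			continue
		}

		filePath := filepath.Join(path, file.Name())
		if file.IsDir() {
			if err := crawlFilesSkippingVCS(ctx, filePath, cb); err != nil {
				return err
			}
			continue
		}

		if err := cb(ctx, filePath); err != nil {
			return err
		}
	}

	return nil
}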
// chunkFile reads a file and splits it into chunks of a size suitable for embedding.
// This is a very simple algorithm right now; it would be better to use a lexer to identify good points in the AST to
// split on (e.g. functions, or groups of functions). We could also implement a reference graph to find the most
// relevant files based on the relationships between files.
func chunkFile(_ context.Context, filePath string, maxBytes int, chunkCb func(chunk []byte, start, end uint64) error) error {
	fileBytes, err := os.ReadFile(filePath)
	if err != nil {
		return err
	}

	pos := 0
	for pos < len(fileBytes) {
		// The final chunk may be shorter than maxBytes.
		nextChunkSize := maxBytes
		if pos+maxBytes > len(fileBytes) {
			nextChunkSize = len(fileBytes) - pos
		}

		if err := chunkCb(fileBytes[pos:pos+nextChunkSize], uint64(pos), uint64(pos+nextChunkSize)); err != nil {
			return err
		}

		pos += nextChunkSize
	}

	return nil
}
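// The comment on chunkFile suggests splitting on AST boundaries rather than fixed byte offsets. The sketch below is
// one hypothetical way to do that for Go sources using the standard go/parser and go/token packages: it emits one
// chunk per top-level declaration and ignores maxBytes entirely, so a fuller version would still need to merge small
// declarations or split oversized ones. chunkFileByDecl is an illustrative name, it is not called anywhere above, and
// it relies on the "go/parser" and "go/token" imports in the block at the top of this file.
func chunkFileByDecl(_ context.Context, filePath string, chunkCb func(chunk []byte, start, end uint64) error) error {
	fileBytes, err := os.ReadFile(filePath)
	if err != nil {
		return err
	}

	fset := token.NewFileSet()
	parsed, err := parser.ParseFile(fset, filePath, fileBytes, parser.ParseComments)
	if err != nil {
		return err
	}

	// Emit one chunk per top-level declaration (functions, type blocks, var/const groups, ...).
	for _, decl := range parsed.Decls {
		start := fset.Position(decl.Pos()).Offset
		end := fset.Position(decl.End()).Offset
		if err := chunkCb(fileBytes[start:end], uint64(start), uint64(end)); err != nil {
			return err
		}
	}

	return nil
}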