package llm import ( "ai-code-assistant/pkg/database" "context" "fmt" "github.com/cenkalti/backoff/v5" "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/vectorstores" "github.com/tmc/langchaingo/vectorstores/pgvector" "log/slog" "slices" "strconv" "strings" "time" ) type RelevantDocs struct { CallbacksHandler callbacks.Handler db *database.Database llm *LLM repoID string size int } type FileChunkID struct { Name string ChunkID int Start uint64 End uint64 Score float32 Doc *schema.Document } func NewGetRelevantDocs(db *database.Database, llm *LLM, repoID string, size int) *RelevantDocs { return &RelevantDocs{ db: db, llm: llm, repoID: repoID, size: size, } } func (rd *RelevantDocs) GetRelevantFileChunks(ctx context.Context, query string) ([]*FileChunkID, error) { conn, err := rd.db.DB(ctx) if err != nil { return nil, err } defer conn.Release() vectorStore, err := pgvector.New(ctx, pgvector.WithConn(conn), pgvector.WithEmbedder(rd.llm.Embedder()), pgvector.WithCollectionName("file_chunks"), ) if err != nil { return nil, err } retr := vectorstores.ToRetriever(vectorStore, rd.size, vectorstores.WithFilters(map[string]any{"type": "file_chunk", "repo_id": rd.repoID})) retr.CallbacksHandler = rd.CallbacksHandler docs, err := retr.GetRelevantDocuments(ctx, query) if err != nil { return nil, err } var chunks []*FileChunkID for _, doc := range docs { chunk := &FileChunkID{ Score: doc.Score, } if filePath, ok := doc.Metadata["file_path"].(string); ok { chunk.Name = filePath } if chunkID, ok := doc.Metadata["chunk_id"].(string); ok { id, _ := strconv.ParseInt(chunkID, 10, 64) chunk.ChunkID = int(id) } if start, ok := doc.Metadata["start"].(string); ok { chunk.Start, _ = strconv.ParseUint(start, 10, 64) } if end, ok := doc.Metadata["end"].(string); ok { chunk.End, _ = strconv.ParseUint(end, 10, 64) } chunk.Doc = &doc chunks = append(chunks, chunk) } return chunks, nil } func (rd *RelevantDocs) RankChunks(ctx context.Context, query string, chunks []*FileChunkID) error { var didErr error slices.SortFunc(chunks, func(a, b *FileChunkID) int { if didErr != nil { return 0 } retr, err := rd.CompareChunks(ctx, query, a, b) if err != nil { didErr = err } return retr }) return didErr } func (rd *RelevantDocs) CompareChunks(ctx context.Context, query string, chunk1, chunk2 *FileChunkID) (int, error) { slog.Info("comparing chunks", "chunk_1_name", chunk1.Name, "chunk_1_id", chunk1.ChunkID, "chunk_2_name", chunk2.Name, "chunk_2_id", chunk2.ChunkID) prompt := `Given the following two pieces of code pick the most relevant chunk to the task described below. Reply as a json object in the format {"chunk_id": ""}. Only reply in JSON. Do not include any explanation or code.` prompt += "\n\n" + query + "\n\n" // Now that we have candidates we need to compare them against each other to find the most appropriate place to // inject them. prompt += "-- chunk_id: chunk_1 --\n" prompt += chunk1.Doc.PageContent prompt += "-- chunk_id: chunk_2 --\n" prompt += chunk2.Doc.PageContent op := func() (int, error) { rsp, err := rd.llm.CodePrompt(ctx, prompt) if err != nil { return 0, err } if strings.Contains(rsp, "chunk_1") { return -1, nil } else if strings.Contains(rsp, "chunk_2") { return 1, nil } return 0, fmt.Errorf("compare response didn't contain a chunk id: %s", rsp) } return backoff.Retry(ctx, op, backoff.WithBackOff(backoff.NewConstantBackOff(10*time.Millisecond)), backoff.WithMaxTries(1)) }