diff --git a/internal/server/scanner.go b/internal/server/scanner.go index 24ccbbaa..bfe46af1 100644 --- a/internal/server/scanner.go +++ b/internal/server/scanner.go @@ -447,7 +447,12 @@ func (sw *scanWriter) writeObject(opts ScanWriterParams) bool { } } if sw.output == outputIDs { - wr.WriteString(jsonString(opts.id)) + if opts.distOutput || opts.distance > 0 { + wr.WriteString(`{"id":` + jsonString(opts.id) + + `,"distance":` + strconv.FormatFloat(opts.distance, 'f', -1, 64) + "}") + } else { + wr.WriteString(jsonString(opts.id)) + } } else { wr.WriteString(`{"id":` + jsonString(opts.id)) switch sw.output { @@ -476,7 +481,12 @@ func (sw *scanWriter) writeObject(opts ScanWriterParams) bool { vals := make([]resp.Value, 1, 3) vals[0] = resp.StringValue(opts.id) if sw.output == outputIDs { - sw.values = append(sw.values, vals[0]) + if opts.distOutput || opts.distance > 0 { + vals = append(vals, resp.FloatValue(opts.distance)) + sw.values = append(sw.values, resp.ArrayValue(vals)) + } else { + sw.values = append(sw.values, vals[0]) + } } else { switch sw.output { case outputObjects: diff --git a/tests/keys_search_test.go b/tests/keys_search_test.go index 0415fcf1..2a164ed9 100644 --- a/tests/keys_search_test.go +++ b/tests/keys_search_test.go @@ -46,9 +46,46 @@ func keys_KNN_basic_test(mc *mockServer) error { {"NEARBY", "mykey", "LIMIT", 10, "POINTS", "POINT", 20, 20}, { "[0 [[2 [19 19]] [3 [12 19]] [5 [33 21]] [1 [5 5]] [4 [-5 5]] [6 [52 13]]]]"}, {"NEARBY", "mykey", "LIMIT", 10, "IDS", "POINT", 20, 20, 4000000}, {"[0 [2 3 5 1 4 6]]"}, - {"NEARBY", "mykey", "LIMIT", 10, "IDS", "POINT", 20, 20, 1500000}, {"[0 [2 3 5]]"}, + {"NEARBY", "mykey", "LIMIT", 10, "DISTANCE", "IDS", "POINT", 20, 20, 1500000}, {"[0 [[2 152808.67164037024] [3 895945.1409106688] [5 1448929.5916252395]]]"}, {"NEARBY", "mykey", "LIMIT", 10, "DISTANCE", "POINT", 52, 13, 100}, {`[0 [[6 {"type":"Point","coordinates":[13,52]} 0]]]`}, {"NEARBY", "mykey", "LIMIT", 10, "POINT", 52.1, 13.1, 100000}, {`[0 [[6 {"type":"Point","coordinates":[13,52]}]]]`}, + {"OUTPUT", "json"}, {func(res string) bool { return gjson.Get(res, "ok").Bool() }}, + {"NEARBY", "mykey", "LIMIT", 10, "DISTANCE", "IDS", "POINT", 20, 20, 1500000}, { + func(res string) error { + if !gjson.Get(res, "ok").Bool() { + return errors.New("not ok") + } + if gjson.Get(res, "ids.#").Int() != 3 { + return fmt.Errorf("expected '%d' objects, got '%d'", 3, gjson.Get(res, "ids.#").Int()) + } + if gjson.Get(res, "ids.#.distance|#").Int() != 3 { + return fmt.Errorf("expected '%d' distances, got '%d'", 3, gjson.Get(res, "ids.#.distance|#").Int()) + } + + for i, d := range gjson.Get(res, "ids.#.distance").Array() { + if d.Float() <= 0 { + return fmt.Errorf("expected all distances to be greater than 0: (%d, %f)", i, d.Float()) + } + } + + return nil + }, + }, + {"NEARBY", "mykey", "LIMIT", 10, "DISTANCE", "IDS", "POINT", 52, 13, 100}, { + func(res string) error { + expected := 0.0 + + if !gjson.Get(res, "ok").Bool() { + return errors.New("not ok") + } + + if gjson.Get(res, "ids.0.distance").Float() != expected { + return fmt.Errorf("expected '%f' distances, got '%f'", expected, gjson.Get(res, "ids.0.distance").Float()) + } + + return nil + }, + }, }) }