Skip to content

Commit

Permalink
RSDK-4196 Only allow flat tensors in mlmodel (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjirewis authored Sep 18, 2023
1 parent 599d726 commit 97dd61a
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 621 deletions.
105 changes: 29 additions & 76 deletions src/viam/sdk/services/mlmodel/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@
namespace viam {
namespace sdk {

namespace {

constexpr bool kUseFlatTensors = false;

} // namespace

MLModelServiceClient::MLModelServiceClient(std::string name, std::shared_ptr<grpc::Channel> channel)
: MLModelService(std::move(name)),
channel_(std::move(channel)),
Expand All @@ -45,77 +39,36 @@ std::shared_ptr<MLModelService::named_tensor_views> MLModelServiceClient::infer(
auto* const resp = pb::Arena::CreateMessage<mlpb::InferResponse>(arena.get());
grpc::ClientContext ctx;

if (!kUseFlatTensors) {
struct tensor_storage_and_views {
mlmodel_details::tensor_storage storage;
MLModelService::named_tensor_views views;
};
auto tsav = std::make_shared<tensor_storage_and_views>();

auto& mutable_input_data = *req->mutable_input_data();
auto& mutable_input_data_fields = *mutable_input_data.mutable_fields();

// TODO: Currently, this doesn't validate that we are passing the
// right input type in. We could query the metadata here and
// consult it.
for (const auto& kv : inputs) {
pb::Value& value = mutable_input_data_fields[kv.first];
mlmodel_details::tensor_to_pb_value(kv.second, &value);
}

const auto result = stub_->Infer(&ctx, *req, resp);
if (!result.ok()) {
throw std::runtime_error(result.error_message());
}

// TODO(RSDK-3298): This is an extra RPC on every inference, but
// it is not clear that caching it is safe.
const auto md = metadata();

const auto& output_fields = resp->output_data().fields();
for (const auto& output : md.outputs) {
const auto where = output_fields.find(output.name);
// Ignore any outputs for which we don't have metadata, since
// we can't know what type they should decode to.
if (where != output_fields.end()) {
mlmodel_details::pb_value_to_tensor(
output, where->second, &tsav->storage, &tsav->views);
}
}
auto* const tsav_views = &tsav->views;
return {std::move(tsav), tsav_views};
} else {
struct arena_and_views {
// NOTE: It is not necessary to capture the `resp` pointer
// here, since the lifetime of that object is subsumed by
// the arena.
std::unique_ptr<pb::Arena> arena;
MLModelService::named_tensor_views views;
};
auto aav = std::make_shared<arena_and_views>();
aav->arena = std::move(arena);

auto& input_tensors = *req->mutable_input_tensors()->mutable_tensors();
for (const auto& kv : inputs) {
auto& emplaced = input_tensors[kv.first];
mlmodel_details::copy_sdk_tensor_to_api_tensor(kv.second, &emplaced);
}

const auto result = stub_->Infer(&ctx, *req, resp);
if (!result.ok()) {
throw std::runtime_error(result.error_message());
}

for (const auto& kv : resp->output_tensors().tensors()) {
// NOTE: We don't need to pass in tensor storage here,
// because the backing store for the views is the Arena we
// moved into our result above.
auto tensor = mlmodel_details::make_sdk_tensor_from_api_tensor(kv.second);
aav->views.emplace(kv.first, std::move(tensor));
}
auto* const tsav_views = &aav->views;
return {std::move(aav), tsav_views};
struct arena_and_views {
// NOTE: It is not necessary to capture the `resp` pointer
// here, since the lifetime of that object is subsumed by
// the arena.
std::unique_ptr<pb::Arena> arena;
MLModelService::named_tensor_views views;
};
auto aav = std::make_shared<arena_and_views>();
aav->arena = std::move(arena);

auto& input_tensors = *req->mutable_input_tensors()->mutable_tensors();
for (const auto& kv : inputs) {
auto& emplaced = input_tensors[kv.first];
mlmodel_details::copy_sdk_tensor_to_api_tensor(kv.second, &emplaced);
}

const auto result = stub_->Infer(&ctx, *req, resp);
if (!result.ok()) {
throw std::runtime_error(result.error_message());
}

for (const auto& kv : resp->output_tensors().tensors()) {
// NOTE: We don't need to pass in tensor storage here,
// because the backing store for the views is the Arena we
// moved into our result above.
auto tensor = mlmodel_details::make_sdk_tensor_from_api_tensor(kv.second);
aav->views.emplace(kv.first, std::move(tensor));
}
auto* const tsav_views = &aav->views;
return {std::move(aav), tsav_views};
}

struct MLModelService::metadata MLModelServiceClient::metadata(const AttributeMap& extra) {
Expand Down
Loading

0 comments on commit 97dd61a

Please sign in to comment.