Program Listing for File Node.hpp

Return to documentation for file (include/depthai/pipeline/Node.hpp)

#pragma once

#include <algorithm>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

// project
#include "depthai/openvino/OpenVINO.hpp"
#include "depthai/pipeline/AssetManager.hpp"
#include "depthai/utility/copyable_unique_ptr.hpp"

// depthai-shared
#include "depthai-shared/datatype/DatatypeEnum.hpp"
#include "depthai-shared/properties/Properties.hpp"

// libraries
#include "tl/optional.hpp"

namespace dai {
// fwd declare Pipeline (and its implementation class); this header only needs the declarations.
class Pipeline;
class PipelineImpl;

/**
 * Abstract base class of every pipeline node.
 *
 * Concrete nodes must implement clone() and getName() (both pure virtual below); NodeCRTP provides
 * both. A node holds name-indexed, non-owning pointers to its inputs/outputs (and input/output maps),
 * a weak reference to the PipelineImpl it belongs to, an id, an AssetManager and its properties.
 * Pipeline and PipelineImpl are friends and may access the protected/private state.
 */
class Node {
    friend class Pipeline;
    friend class PipelineImpl;

   public:
    using Id = std::int64_t;
    struct Connection;
    // fwd declare classes (defined further down in this class)
    class Input;
    class Output;
    class InputMap;
    class OutputMap;

   protected:
    // Name -> object registries for this node's outputs/inputs and output/input maps.
    // The pointers are non-owning. NOTE(review): presumably populated by the setXxxRefs() helpers below
    // and keyed by each object's `name` — confirm in the out-of-line definitions.
    std::unordered_map<std::string, Output*> outputRefs;
    std::unordered_map<std::string, Input*> inputRefs;

    std::unordered_map<std::string, OutputMap*> outputMapRefs;
    std::unordered_map<std::string, InputMap*> inputMapRefs;

    // helpers for setting refs: each overload takes either a single pointer or a braced list of pointers
    // (defined out of line; presumably registers them in the matching map above).
    void setOutputRefs(std::initializer_list<Output*> l);
    void setOutputRefs(Output* outRef);
    void setInputRefs(std::initializer_list<Input*> l);
    void setInputRefs(Input* inRef);
    void setOutputMapRefs(std::initializer_list<OutputMap*> l);
    void setOutputMapRefs(OutputMap* outMapRef);
    void setInputMapRefs(std::initializer_list<InputMap*> l);
    void setInputMapRefs(InputMap* inMapRef);

   public:
    // One entry of an accepted/produced datatype list: `datatype` plus a flag telling whether
    // descendants of that datatype count as well.
    struct DatatypeHierarchy {
        DatatypeHierarchy(DatatypeEnum d, bool c) : datatype(d), descendants(c) {}
        DatatypeEnum datatype;
        bool descendants;
    };

    /**
     * A named output of a node.
     *
     * Stores a reference to its parent Node, so copies (e.g. those returned by Node::getOutputs())
     * keep referring to the same parent, and an Output must not be used after that node is gone.
     */
    class Output {
        Node& parent;

       public:
        enum class Type { MSender, SSender };
        // Group name; empty by default, set only by the grouped constructor.
        std::string group = "";
        std::string name;
        Type type;
        // Which types and do descendants count as well?
        std::vector<DatatypeHierarchy> possibleDatatypes;
        Output(Node& par, std::string n, Type t, std::vector<DatatypeHierarchy> types)
            : parent(par), name(std::move(n)), type(t), possibleDatatypes(std::move(types)) {}
        // Same as above, additionally setting `group`.
        Output(Node& par, std::string group, std::string n, Type t, std::vector<DatatypeHierarchy> types)
            : parent(par), group(std::move(group)), name(std::move(n)), type(t), possibleDatatypes(std::move(types)) {}

        // The node this output belongs to.
        Node& getParent() {
            return parent;
        }
        const Node& getParent() const {
            return parent;
        }

        // String representation of this output (defined out of line).
        std::string toString() const;

        // The following are defined out of line; descriptions follow their names.
        // NOTE(review): exact checks and error behaviour live in the implementation — confirm there.

        // Whether `in` presumably belongs to the same pipeline as this output.
        bool isSamePipeline(const Input& in);

        // Whether this output may presumably be connected to `in`.
        bool canConnect(const Input& in);

        // Connections involving this output (returned by value).
        std::vector<Connection> getConnections();

        // Presumably connects this output to `in`.
        void link(const Input& in);

        // Presumably removes a connection from this output to `in`.
        void unlink(const Input& in);
    };

    /**
     * Name -> Output map that also stores a prototype `defaultOutput`.
     *
     * Output has no default constructor (it holds a Node&), so the inherited
     * std::unordered_map::operator[] could not create missing entries; this class declares its own
     * operator[], which hides the base one. NOTE(review): presumably it inserts a copy of
     * `defaultOutput` for an unknown key — confirm in the implementation.
     */
    class OutputMap : public std::unordered_map<std::string, Output> {
        Output defaultOutput;

       public:
        std::string name;
        OutputMap(std::string name, Output defaultOutput);
        OutputMap(Output defaultOutput);
        Output& operator[](const std::string& key);
    };

    /**
     * A named input of a node.
     *
     * Like Output, it stores a reference to its parent Node. Queue behaviour has defaults
     * (`defaultBlocking` = true, `defaultQueueSize` = 8, `defaultWaitForMessage` = false unless a
     * constructor supplies other values) and optional explicit overrides (`blocking`, `queueSize`,
     * `waitForMessage`).
     */
    class Input {
        Node& parent;

       public:
        enum class Type { SReceiver, MReceiver };
        // Group name; empty by default, set only by the grouped constructor.
        std::string group = "";
        std::string name;
        Type type;
        // Defaults, possibly replaced by constructor arguments.
        bool defaultBlocking{true};
        int defaultQueueSize{8};
        // Optional overrides. NOTE(review): presumably set by the setters below and preferred over the
        // defaults by the getters when present — confirm in the implementation.
        tl::optional<bool> blocking;
        tl::optional<int> queueSize;
        // Options - more information about the input
        tl::optional<bool> waitForMessage;
        bool defaultWaitForMessage{false};
        friend class Output;
        // Accepted datatypes; each entry also says whether descendants count as well.
        std::vector<DatatypeHierarchy> possibleDatatypes;

        // Uses the in-class defaults for blocking, queue size and waitForMessage.
        Input(Node& par, std::string n, Type t, std::vector<DatatypeHierarchy> types)
            : parent(par), name(std::move(n)), type(t), possibleDatatypes(std::move(types)) {}

        // Replaces the default blocking mode and queue size.
        Input(Node& par, std::string n, Type t, bool blocking, int queueSize, std::vector<DatatypeHierarchy> types)
            : parent(par), name(std::move(n)), type(t), defaultBlocking(blocking), defaultQueueSize(queueSize), possibleDatatypes(std::move(types)) {}

        // Additionally replaces the default waitForMessage value.
        Input(Node& par, std::string n, Type t, bool blocking, int queueSize, bool waitForMessage, std::vector<DatatypeHierarchy> types)
            : parent(par),
              name(std::move(n)),
              type(t),
              defaultBlocking(blocking),
              defaultQueueSize(queueSize),
              defaultWaitForMessage(waitForMessage),
              possibleDatatypes(std::move(types)) {}

        // Same as above, additionally setting `group`.
        Input(Node& par, std::string group, std::string n, Type t, bool blocking, int queueSize, bool waitForMessage, std::vector<DatatypeHierarchy> types)
            : parent(par),
              group(std::move(group)),
              name(std::move(n)),
              type(t),
              defaultBlocking(blocking),
              defaultQueueSize(queueSize),
              defaultWaitForMessage(waitForMessage),
              possibleDatatypes(std::move(types)) {}

        // The node this input belongs to.
        Node& getParent() {
            return parent;
        }
        const Node& getParent() const {
            return parent;
        }

        // String representation of this input (defined out of line).
        std::string toString() const;

        // Blocking-mode setter/getter (defined out of line).
        void setBlocking(bool blocking);

        bool getBlocking() const;

        // Queue-size setter/getter (defined out of line).
        void setQueueSize(int size);

        int getQueueSize() const;

        // waitForMessage setter/getter (defined out of line).
        void setWaitForMessage(bool waitForMessage);

        bool getWaitForMessage() const;

        // No dedicated member backs these. NOTE(review): presumably implemented in terms of
        // waitForMessage — confirm in the implementation.
        void setReusePreviousMessage(bool reusePreviousMessage);

        bool getReusePreviousMessage() const;
    };

    /**
     * Name -> Input map that also stores a prototype `defaultInput`.
     *
     * Input has no default constructor (it holds a Node&), so this class declares its own operator[],
     * which hides the inherited one. NOTE(review): presumably it inserts a copy of `defaultInput` for an
     * unknown key — confirm in the implementation.
     */
    class InputMap : public std::unordered_map<std::string, Input> {
        Input defaultInput;

       public:
        std::string name;
        InputMap(Input defaultInput);
        InputMap(std::string name, Input defaultInput);
        Input& operator[](const std::string& key);
    };

    // Describes one output -> input link by the owning nodes' ids and the output/input names and groups.
    // std::hash<Connection> (specialized at the end of this header) is a friend.
    struct Connection {
        friend struct std::hash<Connection>;
        // Builds the record from an output and an input (both taken by value; defined out of line).
        Connection(Output out, Input in);
        Id outputId;
        std::string outputName;
        std::string outputGroup;
        Id inputId;
        std::string inputName;
        std::string inputGroup;
        bool operator==(const Connection& rhs) const;
    };

   protected:
    // when Pipeline tries to serialize and construct on remote, it will check if all connected nodes are on same pipeline
    // (weak_ptr: the node does not keep its pipeline alive)
    std::weak_ptr<PipelineImpl> parent;

   public:
    // This node's id, fixed at construction.
    const Id id;

   protected:
    // Assets attached to this node; exposed through getAssetManager().
    AssetManager assetManager;

    // Overridable accessors (defined out of line). getRequiredOpenVINOVersion() returns an optional
    // version; presumably empty when the node has no OpenVINO requirement.
    virtual Properties& getProperties();
    virtual tl::optional<OpenVINO::Version> getRequiredOpenVINOVersion();
    // Owning storage for the node's properties. It is declared before the `properties` reference below,
    // and members are initialized in declaration order. That matters if the constructor binds
    // `properties` to *propertiesHolder (presumably) — keep this order.
    copyable_unique_ptr<Properties> propertiesHolder;

   public:
    // Underlying properties
    // NOTE(review): Node declares no copy constructor, so the implicit one (used via NodeCRTP::clone())
    // copies this reference as-is. The copy's `properties` then refers to the source node's properties,
    // not to its own copied `propertiesHolder` — confirm this is intended.
    Properties& properties;

    // access: the Pipeline this node belongs to (returned by value; defined out of line).
    Pipeline getParentPipeline();
    const Pipeline getParentPipeline() const;

    // Polymorphic copy of the concrete node (NodeCRTP implements it via Derived's copy constructor).
    virtual std::unique_ptr<Node> clone() const = 0;

    // Name of the concrete node type (NodeCRTP returns Derived::NAME).
    virtual const char* getName() const = 0;

    // Copies of this node's outputs (they still reference this node as parent).
    std::vector<Output> getOutputs();

    // Copies of this node's inputs (they still reference this node as parent).
    std::vector<Input> getInputs();

    // Non-owning pointers to this node's outputs.
    std::vector<Output*> getOutputRefs();

    std::vector<const Output*> getOutputRefs() const;

    // Non-owning pointers to this node's inputs.
    std::vector<Input*> getInputRefs();

    std::vector<const Input*> getInputRefs() const;

    // Constructs a node with id `nodeId` for pipeline `p`, taking ownership of `props` (defined out of line).
    Node(const std::shared_ptr<PipelineImpl>& p, Id nodeId, std::unique_ptr<Properties> props);
    virtual ~Node() = default;

    // Access to this node's AssetManager.
    const AssetManager& getAssetManager() const;

    AssetManager& getAssetManager();
};

// Node CRTP class
template <typename Base, typename Derived, typename Props>
class NodeCRTP : public Base {
   public:
    using Properties = Props;
    Properties& properties;
    const char* getName() const override {
        return Derived::NAME;
    };
    std::unique_ptr<Node> clone() const override {
        return std::make_unique<Derived>(static_cast<const Derived&>(*this));
    };

   private:
    NodeCRTP(const std::shared_ptr<PipelineImpl>& par, int64_t nodeId, std::unique_ptr<Properties> props)
        : Base(par, nodeId, std::move(props)), properties(static_cast<Properties&>(Node::properties)) {}
    NodeCRTP(const std::shared_ptr<PipelineImpl>& par, int64_t nodeId) : NodeCRTP(par, nodeId, std::make_unique<Props>()) {}
    friend Derived;
    friend Base;
    friend class PipelineImpl;
};

}  // namespace dai

// Specialization of std::hash for Node::Connection
namespace std {
/**
 * Hash for dai::Node::Connection, so connections can be used as keys of unordered containers.
 *
 * Mixes outputId, outputName, inputId and inputName into one value with the boost::hash_combine step.
 * Each field is mixed in exactly once, so connections differing only in input name hash differently.
 * NOTE(review): this assumes operator== compares all four hashed fields, as required for equal
 * connections to hash equal. outputGroup/inputGroup are not hashed; they could be added if
 * operator== compares them too — confirm.
 */
template <>
struct hash<dai::Node::Connection> {
    size_t operator()(const dai::Node::Connection& obj) const {
        size_t seed = 0;
        // boost::hash_combine-style mixing of one more hash value into `seed`.
        const auto combine = [&seed](size_t value) { seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); };
        const std::hash<dai::Node::Id> hId{};
        const std::hash<std::string> hStr{};
        combine(hId(obj.outputId));
        combine(hStr(obj.outputName));
        combine(hId(obj.inputId));
        combine(hStr(obj.inputName));
        return seed;
    }
};

}  // namespace std