Udostępnij za pośrednictwem


Store different derived classes in collections in C++ and C#: CoVariance, shared_ptr, unique_ptr

I wanted to create a collection container that would hold objects of various types, all derived from a common interface/base class.
In C# and VB, this is easy. Just create a new List<MyBase>() and add elements to it, including various MyBase derived types
Adding elements that are more derived is Covariance (not to be confused with covariance in statistics).

I wanted to do the same in C++. I’ve been using the STL containers, and knew that I could easily do this by storing raw pointers to the objects in, e.g., a vector.
This would require my code to manage the creation and destruction of the objects, meaning memory management.
Wouldn’t it be nice if I could write the code without any memory management at all? In other words, no operator new or delete and no “*” operator: just use smart pointers.

std::shared_ptr and std::unique_ptr are really handy here.

Below is a sample showing both C# and C++ covariance in containers
The memory allocation calls (malloc and free) are used only for diagnostics to ensure no memory leaks and can be removed.

Start Visual Studio View->Team Explorer
Local Git Repositories->New “CoVariance”
Solutions->New->C#->Windows Classic Desktop-> Wpf App-> “WpfCoVariance”

Paste in the C# code below to replace MainWindow.xaml.cs

Right click Solution in Solution Explorer->Add New Project->C++->Windows Desktop->Windows Desktop Wizard->“CppCovariance”
Application Type-> Windows Application (.exe)
Click on Empty Project Checkbox
CppCovariance->Add->New Item->C++ File
If you get this error:
1>Project not selected to build for this solution configuration
Build->Configuration->CppCovariance->Check the “Build checkbox”
Project->Add->New Item->CPP Source file->CoVariance.cpp.
Paste in the CppCode below

Right click on CppCovariance in the Solution explorer and choose “Set as Startup Project”

See also
https://blogs.msdn.microsoft.com/calvin_hsia/2010/03/16/use-a-custom-allocator-for-your-stl-container/
/en-us/dotnet/standard/generics/covariance-and-contravariance

/en-us/cpp/cpp/how-to-create-and-use-shared-ptr-instances
/en-us/cpp/cpp/how-to-create-and-use-unique-ptr-instances

<C# Code>

 using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Windows;
using System.Windows.Controls;
using System.Windows.Data;
using System.Windows.Documents;
using System.Windows.Input;
using System.Windows.Media;
using System.Windows.Media.Imaging;
using System.Windows.Navigation;
using System.Windows.Shapes;

namespace WpfCoVariance
{
    /// <summary>
    /// Interaction logic for MainWindow.xaml.
    /// The constructor fills a single List&lt;MyBase&gt; with alternating
    /// MyDerivedA / MyDerivedB instances and checks that each one dispatches
    /// DoSomething() to its own override.
    /// </summary>
    public partial class MainWindow : Window
    {
        public MainWindow()
        {
            InitializeComponent();
            const int numIter = 10;

            // Even indices get a MyDerivedA, odd indices a MyDerivedB; both are
            // stored through the common MyBase element type.
            var lst = new List<MyBase>();
            for (int idx = 0; idx < numIter; idx++)
            {
                bool isEven = idx % 2 == 0;
                lst.Add(isEven ? (MyBase)new MyDerivedA() : new MyDerivedB());
            }

            // Every element must report the name of its concrete type.
            foreach (MyBase instance in lst)
            {
                string result = instance.DoSomething();
                string expected = instance is MyDerivedA
                    ? $"{nameof(MyDerivedA)}:{nameof(instance.DoSomething)}"
                    : $"{nameof(MyDerivedB)}:{nameof(instance.DoSomething)}";
                Debug.Assert(result == expected, "Didn't get right value");
            }
        }
    }

    /// <summary>
    /// Common base type for the elements stored in the List&lt;MyBase&gt;.
    /// </summary>
    internal abstract class MyBase
    {
        /// <summary>
        /// Implemented by each derived class; returns a string chosen by that class.
        /// </summary>
        public abstract string DoSomething();
    }
    /// <summary>
    /// First concrete MyBase implementation.
    /// </summary>
    class MyDerivedA : MyBase
    {
        /// <summary>
        /// Returns this class name and the method name, separated by a colon.
        /// </summary>
        public override string DoSomething() => $"{nameof(MyDerivedA)}:{nameof(DoSomething)}";
    }
    /// <summary>
    /// Second concrete MyBase implementation.
    /// </summary>
    class MyDerivedB : MyBase
    {
        /// <summary>
        /// Returns this class name and the method name, separated by a colon.
        /// </summary>
        public override string DoSomething() => $"{nameof(MyDerivedB)}:{nameof(DoSomething)}";
    }
}

</C# Code>

<CppCode>

 #include "windows.h"
#include "memory"
#include "vector"
#include "string"

using namespace std;

// Common base for the objects stored in the smart-pointer containers.
// Besides the pure virtual DoSomething(), it carries bookkeeping (instance
// count, total bytes allocated, a dummy malloc'd block) that exists only so
// the tests can verify that every object is destroyed and nothing leaks.
interface MyBase // can use "class" or "struct" here too
{
    //public:  // public needed if it's a "class"
    // these counts, and the _ptrDummyMemBlock are only needed for behavior validation
    static int g_nInstances;           // number of MyBase-derived objects currently alive
    static LONGLONG g_nTotalAllocated; // bytes currently held in dummy blocks

    virtual string DoSomething() = 0; // pure virtual

    // Initialized so that an object on which DoAllocation() was never called
    // (or failed) has a well-defined "no block" state.
    void *_ptrDummyMemBlock = nullptr;
    int _size = 0;

    MyBase()
    {
        g_nInstances++;
    }

    // Derived destructors free _ptrDummyMemBlock; a copy would share the same
    // raw block and free it twice, so copying is disabled.
    MyBase(const MyBase &) = delete;
    MyBase &operator=(const MyBase &) = delete;

    virtual ~MyBase() // what happens if you remove the "virtual" ?
    {
        g_nInstances--;
    }

    // Allocates a dummy block of reqSize bytes and records its size both in
    // _size and in the first int of the block itself.
    // reqSize must be large enough to hold that int.
    void DoAllocation(int reqSize)
    {
        _ASSERT_EXPR(reqSize > 4, L"too small");
        if (reqSize < static_cast<int>(sizeof(int)))
        {
            return; // asserts vanish in release builds; never write past the block
        }
        // put the Free's in the derived classes to demo virtual dtor
        void *block = malloc(reqSize);
        _ASSERT_EXPR(block != nullptr, L"Not enough memory");
        if (block == nullptr)
        {
            return; // leave the object in the "no block" state, counters untouched
        }
        _ptrDummyMemBlock = block;
        _size = reqSize;
        g_nTotalAllocated += reqSize;
        // store the size of the allocation in the allocated block itself
        *static_cast<int *>(_ptrDummyMemBlock) = reqSize;
    }

    // Verifies that the size stored inside the block still matches _size.
    void CheckSize()
    {
        // _ASSERT_EXPR formats its message as a wide string, so it must be L"...".
        _ASSERT_EXPR(_ptrDummyMemBlock != nullptr, L"no memory?");
        if (_ptrDummyMemBlock == nullptr)
        {
            return; // nothing to read when there is no block
        }
        // get the size of the allocated block from the block
        const int storedSize = *static_cast<const int *>(_ptrDummyMemBlock);
        _ASSERT_EXPR(_size == storedSize, L"sizes don't match");
        (void)storedSize; // only used by the assert, which compiles away in release
    }
};

int MyBase::g_nInstances = 0;
LONGLONG MyBase::g_nTotalAllocated = 0;

class MyDerivedA :
    public MyBase
{
public:
    MyDerivedA(int reqSize)
    {
        DoAllocation(reqSize);
    }
    ~MyDerivedA()
    {
        _ASSERT_EXPR(_ptrDummyMemBlock != nullptr, L"_p should be non-null");
        CheckSize();
        free(_ptrDummyMemBlock);
        g_nTotalAllocated -= _size;
        _ptrDummyMemBlock = nullptr;
    }
    string DoSomething()
    {
        CheckSize();
        return "MyDerivedA::DoSomething";
    }
};

class MyDerivedB :
    public   MyBase
{
public:
    MyDerivedB(int reqSize)
    {
        DoAllocation(reqSize);
    }
    ~MyDerivedB()
    {
        _ASSERT_EXPR(_ptrDummyMemBlock != nullptr, L"_ptrDummyMemBlock should be non-null");
        free(_ptrDummyMemBlock);
        g_nTotalAllocated -= _size;
        _ptrDummyMemBlock = nullptr;
    }
    string DoSomething()
    {
        CheckSize();
        return "MyDerivedB::DoSomething";
    }
};

// Stores one MyDerivedA and one MyDerivedB in a vector<shared_ptr<MyBase>>,
// checks that DoSomething() dispatches to the right override, then clears the
// vector and verifies that all instances and dummy allocations are gone.
// numIter: how many times to repeat the whole create/verify/clear cycle.
void DoTestShared_Ptr(int numIter)
{
    // use shared_ptr to be the owner of the object
    vector<shared_ptr<MyBase>> vecSharedPtr;
    for (int i = 0; i < numIter; i++)
    {
        int numA = 0;
        int numB = 0;
        vecSharedPtr.push_back(make_shared<MyDerivedA>(111));
        vecSharedPtr.push_back(make_shared<MyDerivedB>(222));
        _ASSERT_EXPR(MyBase::g_nInstances == 2, L"should have 2 instances");
        // Iterate by const reference: copying a shared_ptr costs an atomic
        // reference-count increment and decrement per element for no benefit.
        for (const auto &x : vecSharedPtr)
        {
            const auto result = x->DoSomething();
            if (dynamic_pointer_cast<MyDerivedA>(x) != nullptr)
            {
                _ASSERT_EXPR(result == "MyDerivedA::DoSomething", L"should be MyDerivedA");
                numA++;
            }
            else
            {
                _ASSERT_EXPR(dynamic_pointer_cast<MyDerivedB>(x) != nullptr, L"should be MyDerivedB");
                _ASSERT_EXPR(result == "MyDerivedB::DoSomething", L"should be MyDerivedB");
                numB++;
            }
        }
        _ASSERT_EXPR(numA == 1 && numB == 1, L"should have 1 of each instance");
        // Dropping the only owners must run the derived destructors.
        vecSharedPtr.clear();
        _ASSERT_EXPR(MyBase::g_nInstances == 0, L"should have no instances");
        _ASSERT_EXPR(MyBase::g_nTotalAllocated == 0, L"should have none allocated");
    }
}

// Stores one MyDerivedA and one MyDerivedB in a vector<unique_ptr<MyBase>>,
// checks the result of DoSomething() for each, then clears the vector and
// verifies that all instances and dummy allocations are gone.
// numIter: how many times to repeat the whole create/verify/clear cycle.
void DoTestUnique_Ptr(int numIter)
{
    // use unique_ptr to be the sole owner of the object.
    vector<unique_ptr<MyBase>> vecUniquePtr;
    for (int i = 0; i < numIter; i++)
    {
        vecUniquePtr.push_back(make_unique<MyDerivedA>(123));
        vecUniquePtr.push_back(make_unique<MyDerivedB>(456));
        _ASSERT_EXPR(MyBase::g_nInstances == 2, L"should have 2 instances");
        // because we're using unique_ptr, we must iterate using a ref to the object,
        // rather than a copy
        for (const auto &x : vecUniquePtr)
        {
            const auto result = x->DoSomething();
            _ASSERT_EXPR(result == "MyDerivedA::DoSomething" ||
                result == "MyDerivedB::DoSomething"
                , L"should be MyDerivedA or MyDerivedB");
            // no dynamic_pointer_cast for unique_ptr; use dynamic_cast on x.get()
            // if the concrete type is needed
        }
        // Destroying the sole owners must run the derived destructors.
        vecUniquePtr.clear();
        _ASSERT_EXPR(MyBase::g_nInstances == 0, L"should have no instances");
        _ASSERT_EXPR(MyBase::g_nTotalAllocated == 0, L"should have none allocated");
    }
}

// Windows GUI entry point: runs both container tests and then shows a
// message box so the user can see that the run completed.
// None of the four parameters is used.
int APIENTRY wWinMain(_In_ HINSTANCE hInstance,
    _In_opt_ HINSTANCE hPrevInstance,
    _In_ LPWSTR    lpCmdLine,
    _In_ int       nCmdShow)
{
    UNREFERENCED_PARAMETER(hInstance);
    UNREFERENCED_PARAMETER(hPrevInstance);
    UNREFERENCED_PARAMETER(lpCmdLine);
    UNREFERENCED_PARAMETER(nCmdShow);
    const int numIter = 10;
    DoTestShared_Ptr(numIter);
    DoTestUnique_Ptr(numIter);

    MessageBoxA(nullptr, "Done", "CppCovariance", MB_OK);
    // Only main() gets an implicit "return 0"; wWinMain must return explicitly.
    return 0;
}

</CppCode>